Commit 411dcb76 authored by rusty1s's avatar rusty1s
Browse files

select methods

parent 7f7036cd
...@@ -23,3 +23,5 @@ __all__ = [ ...@@ -23,3 +23,5 @@ __all__ = [
from .storage import SparseStorage from .storage import SparseStorage
from .tensor import SparseTensor from .tensor import SparseTensor
from .transpose import t from .transpose import t
from .narrow import narrow
from .select import select
import torch import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def narrow(src, dim, start, length): @torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int, length: int):
dim = src.dim() + dim if dim < 0 else dim dim = src.dim() + dim if dim < 0 else dim
start = src.size(dim) + start if start < 0 else start start = src.size(dim) + start if start < 0 else start
if dim == 0: if dim == 0:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
# rowptr = src.storage.rowptr
# Maintain `rowcount`...
rowcount = src.storage._rowcount
if rowcount is not None:
rowcount = rowcount.narrow(0, start=start, length=length)
rowptr = rowptr.narrow(0, start=start, length=length + 1) rowptr = rowptr.narrow(0, start=start, length=length + 1)
row_start = rowptr[0] row_start = rowptr[0]
...@@ -22,46 +19,60 @@ def narrow(src, dim, start, length): ...@@ -22,46 +19,60 @@ def narrow(src, dim, start, length):
row = src.storage._row row = src.storage._row
if row is not None: if row is not None:
row = row.narrow(0, row_start, row_length) - start row = row.narrow(0, row_start, row_length) - start
col = col.narrow(0, row_start, row_length) col = col.narrow(0, row_start, row_length)
if src.has_value(): if value is not None:
value = value.narrow(0, row_start, row_length) value = value.narrow(0, row_start, row_length)
sparse_size = torch.Size([length, src.sparse_size(1)]) sparse_sizes = torch.Size([length, src.sparse_size(1)])
storage = src.storage.__class__(row=row, rowptr=rowptr, col=col, rowcount = src.storage._rowcount
value=value, sparse_size=sparse_size, if rowcount is not None:
rowcount=rowcount, is_sorted=True) rowcount = rowcount.narrow(0, start=start, length=length)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
elif dim == 1: elif dim == 1:
# This is faster than accessing `csc()` contrary to the `dim=0` case. # This is faster than accessing `csc()` contrary to the `dim=0` case.
row, col, value = src.coo() row, col, value = src.coo()
mask = (col >= start) & (col < start + length) mask = (col >= start) & (col < start + length)
row, col = row[mask], col[mask] - start row = row[mask]
col = col[mask] - start
# Maintain `colcount`... if value is not None:
colcount = src.storage._colcount value = value[mask]
if colcount is not None:
colcount = colcount.narrow(0, start=start, length=length) sparse_sizes = torch.Size([src.sparse_size(0), length])
# Maintain `colptr`...
colptr = src.storage._colptr colptr = src.storage._colptr
if colptr is not None: if colptr is not None:
colptr = colptr.narrow(0, start=start, length=length + 1) colptr = colptr.narrow(0, start=start, length=length + 1)
colptr = colptr - colptr[0] colptr = colptr - colptr[0]
if src.has_value(): colcount = src.storage._colcount
value = value[mask] if colcount is not None:
colcount = colcount.narrow(0, start=start, length=length)
sparse_size = torch.Size([src.sparse_size(0), length])
storage = src.storage.__class__(row=row, col=col, value=value, storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_size=sparse_size, colptr=colptr, sparse_sizes=sparse_sizes, rowcount=None,
colcount=colcount, is_sorted=True) colptr=colptr, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
else: else:
storage = src.storage.apply_value( value = src.storage.value()
lambda x: x.narrow(dim - 1, start, length)) if value is not None:
return src.set_value(value.narrow(dim - 1, start, length),
layout='coo')
else:
raise ValueError
return src.from_storage(storage) SparseTensor.narrow = lambda self, dim, start, length: narrow(
self, dim, start, length)
import torch

from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow


@torch.jit.script
def select(src: SparseTensor, dim: int, idx: int):
    """Select a single slice of ``src`` along dimension ``dim`` at ``idx``.

    Implemented as ``narrow(src, dim, start=idx, length=1)``, so the selected
    dimension is kept with size 1 (it is not squeezed away).
    """
    return narrow(src, dim, start=idx, length=1)


# Expose as a SparseTensor method: ``tensor.select(dim, idx)``.
SparseTensor.select = lambda self, dim, idx: select(self, dim, idx)
...@@ -6,12 +6,10 @@ import scipy.sparse ...@@ -6,12 +6,10 @@ import scipy.sparse
from torch_sparse.storage import SparseStorage, get_layout from torch_sparse.storage import SparseStorage, get_layout
# from torch_sparse.narrow import narrow
# from torch_sparse.select import select
# from torch_sparse.index_select import index_select, index_select_nnz # from torch_sparse.index_select import index_select, index_select_nnz
# from torch_sparse.masked_select import masked_select, masked_select_nnz # from torch_sparse.masked_select import masked_select, masked_select_nnz
# import torch_sparse.reduce
# from torch_sparse.diag import remove_diag, set_diag # from torch_sparse.diag import remove_diag, set_diag
# import torch_sparse.reduce
# from torch_sparse.matmul import matmul # from torch_sparse.matmul import matmul
# from torch_sparse.add import add, add_, add_nnz, add_nnz_ # from torch_sparse.add import add, add_, add_nnz, add_nnz_
# from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_ # from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
......
...@@ -5,32 +5,6 @@ from torch_sparse.storage import SparseStorage ...@@ -5,32 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
def transpose(index, value, m, n, coalesced=True):
"""Transposes dimensions 0 and 1 of a sparse tensor.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of corresponding dense matrix.
n (int): The second dimension of corresponding dense matrix.
coalesced (bool, optional): If set to :obj:`False`, will not coalesce
the output. (default: :obj:`True`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
if value.dim() == 1 and not value.is_cuda:
mat = to_scipy(index, value, m, n).tocsc()
(col, row), value = from_scipy(mat)
index = torch.stack([row, col], dim=0)
return index, value
row, col = index
index = torch.stack([col, row], dim=0)
if coalesced:
index, value = coalesce(index, value, n, m)
return index, value
@torch.jit.script @torch.jit.script
def t(src: SparseTensor): def t(src: SparseTensor):
csr2csc = src.storage.csr2csc() csr2csc = src.storage.csr2csc()
...@@ -60,3 +34,31 @@ def t(src: SparseTensor): ...@@ -60,3 +34,31 @@ def t(src: SparseTensor):
SparseTensor.t = lambda self: t(self) SparseTensor.t = lambda self: t(self)
###############################################################################
def transpose(index, value, m, n, coalesced=True):
    """Transposes dimensions 0 and 1 of a sparse tensor.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`False`, will not coalesce
            the output. (default: :obj:`True`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # Fast path: scalar (1-dim) values on CPU round-trip through scipy's CSC
    # conversion, which produces a transposed, already-coalesced result.
    if value.dim() == 1 and not value.is_cuda:
        csc = to_scipy(index, value, m, n).tocsc()
        (col, row), value = from_scipy(csc)
        return torch.stack([row, col], dim=0), value

    # General path: swap the row/column index planes, then (optionally)
    # coalesce with the transposed dense shape (n, m).
    swapped = torch.stack([index[1], index[0]], dim=0)
    if coalesced:
        swapped, value = coalesce(swapped, value, n, m)
    return swapped, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment