Commit 57d37233 authored by rusty1s's avatar rusty1s
Browse files

better colptr computation

parent 0692b905
import warnings
import torch
from torch_scatter import segment_csr
from torch_scatter import segment_csr, scatter_add
from torch_sparse import rowptr_cpu
......@@ -70,17 +70,9 @@ class SparseStorage(object):
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
]
def __init__(self,
index,
value=None,
sparse_size=None,
rowcount=None,
rowptr=None,
colcount=None,
colptr=None,
csr2csc=None,
csc2csr=None,
is_sorted=False):
def __init__(self, index, value=None, sparse_size=None, rowcount=None,
rowptr=None, colcount=None, colptr=None, csr2csc=None,
csc2csr=None, is_sorted=False):
assert index.dtype == torch.long
assert index.dim() == 2 and index.size(0) == 2
......@@ -199,31 +191,59 @@ class SparseStorage(object):
self._sparse_size = sizes
return self
def has_rowcount(self):
return self._rowcount is not None
@cached_property
def rowcount(self):
rowptr = self.rowptr
return rowptr[1:] - rowptr[:-1]
def has_rowptr(self):
return self._rowptr is not None
@cached_property
def rowptr(self):
func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
return func.rowptr(self.row, self.sparse_size(0))
def has_colcount(self):
return self._colcount is not None
@cached_property
def colcount(self):
colptr = self.colptr
return colptr[1:] - colptr[:-1]
if self._colptr is not None:
colptr = self.colptr
return colptr[1:] - colptr[:-1]
else:
col, dim_size = self.col, self.sparse_size(1)
return scatter_add(torch.ones_like(col), col, dim_size=dim_size)
def has_colptr(self):
return self._colptr is not None
@cached_property
def colptr(self):
func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
return func.rowptr(self.col[self.csr2csc], self.sparse_size(1))
if self._csr2csc:
func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
return func.rowptr(self.col[self.csr2csc], self.sparse_size(1))
else:
colcount = self.colcount
colptr = colcount.new_zeros(colcount.size(0) + 1)
torch.cumsum(colcount, dim=0, out=colptr[1:])
return colptr
def has_csr2csc(self):
return self._csr2csc is not None
@cached_property
def csr2csc(self):
idx = self._sparse_size[0] * self.col + self.row
return idx.argsort()
def has_csc2csr(self):
return self._csc2csr is not None
@cached_property
def csc2csr(self):
return self.csr2csc.argsort()
......@@ -237,7 +257,7 @@ class SparseStorage(object):
idx = self.sparse_size(1) * self.row + self.col
mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
if mask.all(): # Already coalesced
if mask.all(): # Skip if indices are already coalesced.
return self
index = self.index[:, mask]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment