Commit 2ae73b17 authored by rusty1s's avatar rusty1s
Browse files

new storage format

parent fa763bac
......@@ -68,89 +68,114 @@ def get_layout(layout=None):
class SparseStorage(object):
cache_keys = [
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
]
cache_keys = ['rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr']
def __init__(self, index, value=None, sparse_size=None, rowcount=None,
rowptr=None, colcount=None, colptr=None, csr2csc=None,
csc2csr=None, is_sorted=False):
def __init__(self, row=None, rowptr=None, col=None, value=None,
sparse_size=None, rowcount=None, colptr=None, colcount=None,
csr2csc=None, csc2csr=None, is_sorted=False):
assert index.dtype == torch.long
assert index.dim() == 2 and index.size(0) == 2
index = index.contiguous()
if value is not None:
assert value.device == index.device
assert value.size(0) == index.size(1)
value = value.contiguous()
assert row is not None or rowptr is not None
assert col is not None
assert col.dtype == torch.long
assert col.dim() == 1
if sparse_size is None:
sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())
M = rowptr.numel() - 1 if rowptr is None else row.max().item() + 1
N = col.max().item() + 1
sparse_size = torch.Size([M, N])
if rowcount is not None:
assert rowcount.dtype == torch.long
assert rowcount.device == index.device
assert rowcount.dim() == 1 and rowcount.numel() == sparse_size[0]
if row is not None:
assert row.dtype == torch.long
assert row.device == col.device
assert row.dim() == 1
assert row.numel() == col.numel()
if rowptr is not None:
assert rowptr.dtype == torch.long
assert rowptr.device == index.device
assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0]
assert rowptr.device == col.device
assert rowptr.dim() == 1
assert rowptr.numel() - 1 == sparse_size[0]
if colcount is not None:
assert colcount.dtype == torch.long
assert colcount.device == index.device
assert colcount.dim() == 1 and colcount.numel() == sparse_size[1]
if value is not None:
assert value.device == col.device
assert value.size(0) == col.size(0)
value = value.contiguous()
if rowcount is not None:
assert rowcount.dtype == torch.long
assert rowcount.device == col.device
assert rowcount.dim() == 1
assert rowcount.numel() == sparse_size[0]
if colptr is not None:
assert colptr.dtype == torch.long
assert colptr.device == index.device
assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1]
assert colptr.device == col.device
assert colptr.dim() == 1
assert colptr.numel() - 1 == sparse_size[1]
if colcount is not None:
assert colcount.dtype == torch.long
assert colcount.device == col.device
assert colcount.dim() == 1
assert colcount.numel() == sparse_size[1]
if csr2csc is not None:
assert csr2csc.dtype == torch.long
assert csr2csc.device == index.device
assert csr2csc.device == col.device
assert csr2csc.dim() == 1
assert csr2csc.numel() == index.size(1)
assert csr2csc.numel() == col.size(0)
if csc2csr is not None:
assert csc2csr.dtype == torch.long
assert csc2csr.device == index.device
assert csc2csr.device == col.device
assert csc2csr.dim() == 1
assert csc2csr.numel() == index.size(1)
assert csc2csr.numel() == col.size(0)
if not is_sorted:
idx = sparse_size[1] * index[0] + index[1]
# Only sort if necessary...
if (idx < torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
perm = idx.argsort()
index = index[:, perm]
value = None if value is None else value[perm]
csr2csc = None
csc2csr = None
self._index = index
self._row = row
self._rowptr = rowptr
self._col = col
self._value = value
self._sparse_size = sparse_size
self._rowcount = rowcount
self._rowptr = rowptr
self._colcount = colcount
self._colptr = colptr
self._colcount = colcount
self._csr2csc = csr2csc
self._csc2csr = csc2csr
@property
def index(self):
return self._index
if not is_sorted:
idx = self.col.new_zeros(col.numel() + 1)
idx[1:] = sparse_size[1] * self.row + self.col
if (idx[1:] < idx[:-1]).any():
perm = idx.argsort()
self._row = self.row[perm]
self._col = self.col[perm]
self._value = self.value[perm] if self.has_value() else None
self._csr2csc = None
self._csc2csr = None
def has_row(self):
return self._row is not None
@property
def row(self):
return self._index[0]
if self._row is None:
# TODO
pass
return self._row
def has_rowptr(self):
return self._rowptr is not None
@property
def rowptr(self):
if self._rowptr is None:
func = rowptr_cuda if self.row.is_cuda else rowptr_cpu
self._rowptr = func.rowptr(self.row, self.sparse_size[0])
return self._rowptr
@property
def col(self):
return self._index[1]
return self._col
def has_value(self):
return self._value is not None
......@@ -159,99 +184,99 @@ class SparseStorage(object):
def value(self):
return self._value
def set_value_(self, value, layout=None):
def set_value_(self, value, dtype=None, layout=None):
if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), device=self.index.device)
value = torch.full((self.nnz(), ), dtype=dtype,
device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value = value[self.csc2csr]
if torch.is_tensor(value):
assert value.device == self.index.device
assert value.size(0) == self.index.size(1)
value = value if dtype is None else value.to(dtype)
assert value.device == self.col.device
assert value.size(0) == self.col.numel()
self._value = value
return self
def set_value(self, value, layout=None):
def set_value(self, value, dtype=None, layout=None):
if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), device=self.index.device)
value = torch.full((self.nnz(), ), dtype=dtype,
device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value = value[self.csc2csr]
if torch.is_tensor(value):
assert value.device == self._index.device
assert value.size(0) == self._index.size(1)
return self.__class__(
self._index,
value,
self._sparse_size,
self._rowcount,
self._rowptr,
self._colcount,
self._colptr,
self._csr2csc,
self._csc2csr,
is_sorted=True,
)
value = value if dtype is None else value.to(dtype)
assert value.device == self.col.device
assert value.size(0) == self.col.numel()
def sparse_size(self, dim=None):
return self._sparse_size if dim is None else self._sparse_size[dim]
return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
value=value, sparse_size=self._sparse_size,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
@property
def sparse_size(self):
return self._sparse_size
def sparse_resize(self, *sizes):
assert len(sizes) == 2
old_sizes, nnz = self.sparse_size(), self.nnz()
old_sparse_size, nnz = self.sparse_size, self.nnz()
diff_0 = sizes[0] - old_sizes[0]
diff_0 = sizes[0] - old_sparse_size[0]
rowcount, rowptr = self._rowcount, self._rowptr
if diff_0 > 0:
if self.has_rowcount():
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
if self.has_rowptr():
if rowptr is not None:
rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
if rowcount is not None:
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
else:
if self.has_rowcount():
rowcount = rowcount[:-diff_0]
if self.has_rowptr():
if rowptr is not None:
rowptr = rowptr[:-diff_0]
if rowcount is not None:
rowcount = rowcount[:-diff_0]
diff_1 = sizes[1] - old_sizes[1]
diff_1 = sizes[1] - old_sparse_size[1]
colcount, colptr = self._colcount, self._colptr
if diff_1 > 0:
if self.has_colcount():
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
if self.has_colptr():
if colptr is not None:
colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
if colcount is not None:
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
else:
if self.has_colcount():
colcount = colcount[:-diff_1]
if self.has_colptr():
if colptr is not None:
colptr = colptr[:-diff_1]
if colcount is not None:
colcount = colcount[:-diff_1]
return self.__class__(
self._index,
self._value,
sizes,
rowcount=rowcount,
rowptr=rowptr,
colcount=colcount,
colptr=colptr,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
)
return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
value=self.value, sparse_size=sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def has_rowcount(self):
return self._rowcount is not None
@cached_property
def rowcount(self):
rowptr = self.rowptr
return rowptr[1:] - rowptr[:-1]
return self.rowptr[1:] - self.rowptr[:-1]
def has_rowptr(self):
return self._rowptr is not None
def has_colptr(self):
return self._colptr is not None
@cached_property
def rowptr(self):
func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
return func.rowptr(self.row, self.sparse_size(0))
def colptr(self):
if self.has_csr2csc():
func = rowptr_cuda if self.col.is_cuda else rowptr_cpu
return func.rowptr(self.col[self.csr2csc], self.sparse_size[1])
else:
colptr = self.col.new_zeros(self.sparse_size[1] + 1)
torch.cumsum(self.colcount, dim=0, out=colptr[1:])
return colptr
def has_colcount(self):
return self._colcount is not None
......@@ -259,32 +284,17 @@ class SparseStorage(object):
@cached_property
def colcount(self):
if self.has_colptr():
colptr = self.colptr
return colptr[1:] - colptr[:-1]
else:
col, dim_size = self.col, self.sparse_size(1)
return scatter_add(torch.ones_like(col), col, dim_size=dim_size)
def has_colptr(self):
return self._colptr is not None
@cached_property
def colptr(self):
if self.has_csr2csc():
func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
return func.rowptr(self.col[self.csr2csc], self.sparse_size(1))
return self.colptr[1:] - self.colptr[:-1]
else:
colcount = self.colcount
colptr = colcount.new_zeros(colcount.size(0) + 1)
torch.cumsum(colcount, dim=0, out=colptr[1:])
return colptr
return scatter_add(torch.ones_like(self.col), self.col,
dim_size=self.sparse_size[1])
def has_csr2csc(self):
return self._csr2csc is not None
@cached_property
def csr2csc(self):
idx = self._sparse_size[0] * self.col + self.row
idx = self.sparse_size[0] * self.col + self.row
return idx.argsort()
def has_csc2csr(self):
......@@ -295,26 +305,29 @@ class SparseStorage(object):
return self.csr2csc.argsort()
def is_coalesced(self):
idx = self.sparse_size(1) * self.row + self.col
mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
return mask.all().item()
idx = self.col.new_zeros(self.col.numel() + 1)
idx[1:] = self.sparse_size[1] * self.row + self.col
return (idx[1:] > idx[:-1]).all().item()
def coalesce(self, reduce='add'):
idx = self.sparse_size(1) * self.row + self.col
mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
idx = self.col.new_zeros(self.col.numel() + 1)
idx[1:] = self.sparse_size[1] * self.row + self.col
mask = idx[1:] > idx[:-1]
if mask.all(): # Skip if indices are already coalesced.
return self
index = self.index[:, mask]
row = self.row[mask]
col = self.col[mask]
value = self.value
if self.has_value():
idx = mask.cumsum(0) - 1
idx = mask.cumsum(0).sub_(1)
value = segment_csr(idx, value, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return self.__class__(index, value, self.sparse_size(), is_sorted=True)
return self.__class__(row=row, col=col, value=value,
sparse_size=self.sparse_size, is_sorted=True)
def cached_keys(self):
return [
......@@ -323,7 +336,7 @@ class SparseStorage(object):
]
def fill_cache_(self, *args):
for arg in args or self.cache_keys:
for arg in args or self.cache_keys + ['row', 'rowptr']:
getattr(self, arg)
return self
......@@ -344,46 +357,48 @@ class SparseStorage(object):
return new_storage
def apply_value_(self, func):
self._value = optional(func, self._value)
self._value = optional(func, self.value)
return self
def apply_value(self, func):
return self.__class__(
self._index,
optional(func, self._value),
self._sparse_size,
self._rowcount,
self._rowptr,
self._colcount,
self._colptr,
self._csr2csc,
self._csc2csr,
is_sorted=True,
)
return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
value=optional(func, self.value),
sparse_size=self.sparse_size,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def apply_(self, func):
self._index = func(self._index)
self._value = optional(func, self._value)
self._row = optional(func, self._row)
self._rowptr = optional(func, self._rowptr)
self._col = func(self.col)
self._value = optional(func, self.value)
for key in self.cached_keys():
setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
return self
def apply(self, func):
return self.__class__(
func(self._index),
optional(func, self._value),
self._sparse_size,
optional(func, self._rowcount),
optional(func, self._rowptr),
optional(func, self._colcount),
optional(func, self._colptr),
optional(func, self._csr2csc),
optional(func, self._csc2csr),
row=optional(func, self._row),
rowptr=optional(func, self._rowptr),
col=func(self.col),
value=optional(func, self.value),
sparse_size=self.sparse_size,
rowcount=optional(func, self._rowcount),
colptr=optional(func, self._colptr),
colcount=optional(func, self._colcount),
csr2csc=optional(func, self._csr2csc),
csc2csr=optional(func, self._csc2csr),
is_sorted=True,
)
def map(self, func):
data = [func(self.index)]
data = []
if self.has_row():
data += [func(self.row)]
if self.has_rowptr():
data += [func(self.rowptr)]
data += [func(self.col)]
if self.has_value():
data += [func(self.value)]
data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment