Commit 2ae73b17 authored by rusty1s's avatar rusty1s
Browse files

new storage format

parent fa763bac
...@@ -68,89 +68,114 @@ def get_layout(layout=None): ...@@ -68,89 +68,114 @@ def get_layout(layout=None):
class SparseStorage(object): class SparseStorage(object):
cache_keys = [ cache_keys = ['rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr']
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
]
def __init__(self, index, value=None, sparse_size=None, rowcount=None, def __init__(self, row=None, rowptr=None, col=None, value=None,
rowptr=None, colcount=None, colptr=None, csr2csc=None, sparse_size=None, rowcount=None, colptr=None, colcount=None,
csc2csr=None, is_sorted=False): csr2csc=None, csc2csr=None, is_sorted=False):
assert index.dtype == torch.long assert row is not None or rowptr is not None
assert index.dim() == 2 and index.size(0) == 2 assert col is not None
index = index.contiguous() assert col.dtype == torch.long
assert col.dim() == 1
if value is not None:
assert value.device == index.device
assert value.size(0) == index.size(1)
value = value.contiguous()
if sparse_size is None: if sparse_size is None:
sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist()) M = rowptr.numel() - 1 if rowptr is None else row.max().item() + 1
N = col.max().item() + 1
sparse_size = torch.Size([M, N])
if rowcount is not None: if row is not None:
assert rowcount.dtype == torch.long assert row.dtype == torch.long
assert rowcount.device == index.device assert row.device == col.device
assert rowcount.dim() == 1 and rowcount.numel() == sparse_size[0] assert row.dim() == 1
assert row.numel() == col.numel()
if rowptr is not None: if rowptr is not None:
assert rowptr.dtype == torch.long assert rowptr.dtype == torch.long
assert rowptr.device == index.device assert rowptr.device == col.device
assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0] assert rowptr.dim() == 1
assert rowptr.numel() - 1 == sparse_size[0]
if colcount is not None: if value is not None:
assert colcount.dtype == torch.long assert value.device == col.device
assert colcount.device == index.device assert value.size(0) == col.size(0)
assert colcount.dim() == 1 and colcount.numel() == sparse_size[1] value = value.contiguous()
if rowcount is not None:
assert rowcount.dtype == torch.long
assert rowcount.device == col.device
assert rowcount.dim() == 1
assert rowcount.numel() == sparse_size[0]
if colptr is not None: if colptr is not None:
assert colptr.dtype == torch.long assert colptr.dtype == torch.long
assert colptr.device == index.device assert colptr.device == col.device
assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1] assert colptr.dim() == 1
assert colptr.numel() - 1 == sparse_size[1]
if colcount is not None:
assert colcount.dtype == torch.long
assert colcount.device == col.device
assert colcount.dim() == 1
assert colcount.numel() == sparse_size[1]
if csr2csc is not None: if csr2csc is not None:
assert csr2csc.dtype == torch.long assert csr2csc.dtype == torch.long
assert csr2csc.device == index.device assert csr2csc.device == col.device
assert csr2csc.dim() == 1 assert csr2csc.dim() == 1
assert csr2csc.numel() == index.size(1) assert csr2csc.numel() == col.size(0)
if csc2csr is not None: if csc2csr is not None:
assert csc2csr.dtype == torch.long assert csc2csr.dtype == torch.long
assert csc2csr.device == index.device assert csc2csr.device == col.device
assert csc2csr.dim() == 1 assert csc2csr.dim() == 1
assert csc2csr.numel() == index.size(1) assert csc2csr.numel() == col.size(0)
if not is_sorted: self._row = row
idx = sparse_size[1] * index[0] + index[1] self._rowptr = rowptr
# Only sort if necessary... self._col = col
if (idx < torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
perm = idx.argsort()
index = index[:, perm]
value = None if value is None else value[perm]
csr2csc = None
csc2csr = None
self._index = index
self._value = value self._value = value
self._sparse_size = sparse_size self._sparse_size = sparse_size
self._rowcount = rowcount self._rowcount = rowcount
self._rowptr = rowptr
self._colcount = colcount
self._colptr = colptr self._colptr = colptr
self._colcount = colcount
self._csr2csc = csr2csc self._csr2csc = csr2csc
self._csc2csr = csc2csr self._csc2csr = csc2csr
@property if not is_sorted:
def index(self): idx = self.col.new_zeros(col.numel() + 1)
return self._index idx[1:] = sparse_size[1] * self.row + self.col
if (idx[1:] < idx[:-1]).any():
perm = idx.argsort()
self._row = self.row[perm]
self._col = self.col[perm]
self._value = self.value[perm] if self.has_value() else None
self._csr2csc = None
self._csc2csr = None
def has_row(self):
return self._row is not None
@property @property
def row(self): def row(self):
return self._index[0] if self._row is None:
# TODO
pass
return self._row
def has_rowptr(self):
return self._rowptr is not None
@property
def rowptr(self):
if self._rowptr is None:
func = rowptr_cuda if self.row.is_cuda else rowptr_cpu
self._rowptr = func.rowptr(self.row, self.sparse_size[0])
return self._rowptr
@property @property
def col(self): def col(self):
return self._index[1] return self._col
def has_value(self): def has_value(self):
return self._value is not None return self._value is not None
...@@ -159,99 +184,99 @@ class SparseStorage(object): ...@@ -159,99 +184,99 @@ class SparseStorage(object):
def value(self): def value(self):
return self._value return self._value
def set_value_(self, value, layout=None): def set_value_(self, value, dtype=None, layout=None):
if isinstance(value, int) or isinstance(value, float): if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), device=self.index.device) value = torch.full((self.nnz(), ), dtype=dtype,
device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc': elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value = value[self.csc2csr] value = value[self.csc2csr]
if torch.is_tensor(value): if torch.is_tensor(value):
assert value.device == self.index.device value = value if dtype is None else value.to(dtype)
assert value.size(0) == self.index.size(1) assert value.device == self.col.device
assert value.size(0) == self.col.numel()
self._value = value self._value = value
return self return self
def set_value(self, value, layout=None): def set_value(self, value, dtype=None, layout=None):
if isinstance(value, int) or isinstance(value, float): if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), device=self.index.device) value = torch.full((self.nnz(), ), dtype=dtype,
device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc': elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value = value[self.csc2csr] value = value[self.csc2csr]
if torch.is_tensor(value): if torch.is_tensor(value):
assert value.device == self._index.device value = value if dtype is None else value.to(dtype)
assert value.size(0) == self._index.size(1) assert value.device == self.col.device
return self.__class__( assert value.size(0) == self.col.numel()
self._index,
value,
self._sparse_size,
self._rowcount,
self._rowptr,
self._colcount,
self._colptr,
self._csr2csc,
self._csc2csr,
is_sorted=True,
)
def sparse_size(self, dim=None): return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
return self._sparse_size if dim is None else self._sparse_size[dim] value=value, sparse_size=self._sparse_size,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
@property
def sparse_size(self):
return self._sparse_size
def sparse_resize(self, *sizes): def sparse_resize(self, *sizes):
assert len(sizes) == 2 old_sparse_size, nnz = self.sparse_size, self.nnz()
old_sizes, nnz = self.sparse_size(), self.nnz()
diff_0 = sizes[0] - old_sizes[0] diff_0 = sizes[0] - old_sparse_size[0]
rowcount, rowptr = self._rowcount, self._rowptr rowcount, rowptr = self._rowcount, self._rowptr
if diff_0 > 0: if diff_0 > 0:
if self.has_rowcount(): if rowptr is not None:
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
if self.has_rowptr():
rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)]) rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
if rowcount is not None:
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
else: else:
if self.has_rowcount(): if rowptr is not None:
rowcount = rowcount[:-diff_0]
if self.has_rowptr():
rowptr = rowptr[:-diff_0] rowptr = rowptr[:-diff_0]
if rowcount is not None:
rowcount = rowcount[:-diff_0]
diff_1 = sizes[1] - old_sizes[1] diff_1 = sizes[1] - old_sparse_size[1]
colcount, colptr = self._colcount, self._colptr colcount, colptr = self._colcount, self._colptr
if diff_1 > 0: if diff_1 > 0:
if self.has_colcount(): if colptr is not None:
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
if self.has_colptr():
colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)]) colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
if colcount is not None:
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
else: else:
if self.has_colcount(): if colptr is not None:
colcount = colcount[:-diff_1]
if self.has_colptr():
colptr = colptr[:-diff_1] colptr = colptr[:-diff_1]
if colcount is not None:
colcount = colcount[:-diff_1]
return self.__class__( return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
self._index, value=self.value, sparse_size=sizes,
self._value, rowcount=rowcount, colptr=colptr,
sizes, colcount=colcount, csr2csc=self._csr2csc,
rowcount=rowcount, csc2csr=self._csc2csr, is_sorted=True)
rowptr=rowptr,
colcount=colcount,
colptr=colptr,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
)
def has_rowcount(self): def has_rowcount(self):
return self._rowcount is not None return self._rowcount is not None
@cached_property @cached_property
def rowcount(self): def rowcount(self):
rowptr = self.rowptr return self.rowptr[1:] - self.rowptr[:-1]
return rowptr[1:] - rowptr[:-1]
def has_rowptr(self): def has_colptr(self):
return self._rowptr is not None return self._colptr is not None
@cached_property @cached_property
def rowptr(self): def colptr(self):
func = rowptr_cuda if self.index.is_cuda else rowptr_cpu if self.has_csr2csc():
return func.rowptr(self.row, self.sparse_size(0)) func = rowptr_cuda if self.col.is_cuda else rowptr_cpu
return func.rowptr(self.col[self.csr2csc], self.sparse_size[1])
else:
colptr = self.col.new_zeros(self.sparse_size[1] + 1)
torch.cumsum(self.colcount, dim=0, out=colptr[1:])
return colptr
def has_colcount(self): def has_colcount(self):
return self._colcount is not None return self._colcount is not None
...@@ -259,32 +284,17 @@ class SparseStorage(object): ...@@ -259,32 +284,17 @@ class SparseStorage(object):
@cached_property @cached_property
def colcount(self): def colcount(self):
if self.has_colptr(): if self.has_colptr():
colptr = self.colptr return self.colptr[1:] - self.colptr[:-1]
return colptr[1:] - colptr[:-1]
else:
col, dim_size = self.col, self.sparse_size(1)
return scatter_add(torch.ones_like(col), col, dim_size=dim_size)
def has_colptr(self):
return self._colptr is not None
@cached_property
def colptr(self):
if self.has_csr2csc():
func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
return func.rowptr(self.col[self.csr2csc], self.sparse_size(1))
else: else:
colcount = self.colcount return scatter_add(torch.ones_like(self.col), self.col,
colptr = colcount.new_zeros(colcount.size(0) + 1) dim_size=self.sparse_size[1])
torch.cumsum(colcount, dim=0, out=colptr[1:])
return colptr
def has_csr2csc(self): def has_csr2csc(self):
return self._csr2csc is not None return self._csr2csc is not None
@cached_property @cached_property
def csr2csc(self): def csr2csc(self):
idx = self._sparse_size[0] * self.col + self.row idx = self.sparse_size[0] * self.col + self.row
return idx.argsort() return idx.argsort()
def has_csc2csr(self): def has_csc2csr(self):
...@@ -295,26 +305,29 @@ class SparseStorage(object): ...@@ -295,26 +305,29 @@ class SparseStorage(object):
return self.csr2csc.argsort() return self.csr2csc.argsort()
def is_coalesced(self): def is_coalesced(self):
idx = self.sparse_size(1) * self.row + self.col idx = self.col.new_zeros(self.col.numel() + 1)
mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0) idx[1:] = self.sparse_size[1] * self.row + self.col
return mask.all().item() return (idx[1:] > idx[:-1]).all().item()
def coalesce(self, reduce='add'): def coalesce(self, reduce='add'):
idx = self.sparse_size(1) * self.row + self.col idx = self.col.new_zeros(self.col.numel() + 1)
mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0) idx[1:] = self.sparse_size[1] * self.row + self.col
mask = idx[1:] > idx[:-1]
if mask.all(): # Skip if indices are already coalesced. if mask.all(): # Skip if indices are already coalesced.
return self return self
index = self.index[:, mask] row = self.row[mask]
col = self.col[mask]
value = self.value value = self.value
if self.has_value(): if self.has_value():
idx = mask.cumsum(0) - 1 idx = mask.cumsum(0).sub_(1)
value = segment_csr(idx, value, reduce=reduce) value = segment_csr(idx, value, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value value = value[0] if isinstance(value, tuple) else value
return self.__class__(index, value, self.sparse_size(), is_sorted=True) return self.__class__(row=row, col=col, value=value,
sparse_size=self.sparse_size, is_sorted=True)
def cached_keys(self): def cached_keys(self):
return [ return [
...@@ -323,7 +336,7 @@ class SparseStorage(object): ...@@ -323,7 +336,7 @@ class SparseStorage(object):
] ]
def fill_cache_(self, *args): def fill_cache_(self, *args):
for arg in args or self.cache_keys: for arg in args or self.cache_keys + ['row', 'rowptr']:
getattr(self, arg) getattr(self, arg)
return self return self
...@@ -344,46 +357,48 @@ class SparseStorage(object): ...@@ -344,46 +357,48 @@ class SparseStorage(object):
return new_storage return new_storage
def apply_value_(self, func): def apply_value_(self, func):
self._value = optional(func, self._value) self._value = optional(func, self.value)
return self return self
def apply_value(self, func): def apply_value(self, func):
return self.__class__( return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
self._index, value=optional(func, self.value),
optional(func, self._value), sparse_size=self.sparse_size,
self._sparse_size, rowcount=self._rowcount, colptr=self._colptr,
self._rowcount, colcount=self._colcount, csr2csc=self._csr2csc,
self._rowptr, csc2csr=self._csc2csr, is_sorted=True)
self._colcount,
self._colptr,
self._csr2csc,
self._csc2csr,
is_sorted=True,
)
def apply_(self, func): def apply_(self, func):
self._index = func(self._index) self._row = optional(func, self._row)
self._value = optional(func, self._value) self._rowptr = optional(func, self._rowptr)
self._col = func(self.col)
self._value = optional(func, self.value)
for key in self.cached_keys(): for key in self.cached_keys():
setattr(self, f'_{key}', func(getattr(self, f'_{key}'))) setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
return self return self
def apply(self, func): def apply(self, func):
return self.__class__( return self.__class__(
func(self._index), row=optional(func, self._row),
optional(func, self._value), rowptr=optional(func, self._rowptr),
self._sparse_size, col=func(self.col),
optional(func, self._rowcount), value=optional(func, self.value),
optional(func, self._rowptr), sparse_size=self.sparse_size,
optional(func, self._colcount), rowcount=optional(func, self._rowcount),
optional(func, self._colptr), colptr=optional(func, self._colptr),
optional(func, self._csr2csc), colcount=optional(func, self._colcount),
optional(func, self._csc2csr), csr2csc=optional(func, self._csr2csc),
csc2csr=optional(func, self._csc2csr),
is_sorted=True, is_sorted=True,
) )
def map(self, func): def map(self, func):
data = [func(self.index)] data = []
if self.has_row():
data += [func(self.row)]
if self.has_rowptr():
data += [func(self.rowptr)]
data += [func(self.col)]
if self.has_value(): if self.has_value():
data += [func(self.value)] data += [func(self.value)]
data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()] data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment