Commit 51834e88 authored by rusty1s

sparse format

parent a1f64207
import inspect
from textwrap import indent
import torch
from torch_sparse.storage import SparseStorage
methods = list(zip(*inspect.getmembers(SparseStorage)))[0]
methods = [name for name in methods if '__' not in name and name != 'clone']
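# `methods` holds the public member names of SparseStorage; the `_storage`
# setter below copies each of them onto the SparseTensor instance, so
# attributes like `self._row` and calls like `self.nnz()` delegate to the
# underlying storage object.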
class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
assert index.dtype == torch.long
assert index.dim() == 2 and index.size(0) == 2
index = index.contiguous()
if value is not None:
assert value.size(0) == index.size(1)
assert index.device == value.device
value = value.contiguous()
self._storage = SparseStorage(index[0], index[1], value, sparse_size,
is_sorted=is_sorted)
@classmethod
def from_storage(cls, storage):
self = cls.__new__(cls)
self._storage = storage
return self
@classmethod
def from_dense(cls, mat):
if mat.dim() > 2:
index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
else:
index = mat.nonzero()
index = index.t().contiguous()
value = mat[index[0], index[1]]
return SparseTensor(index, value, mat.size()[:2], is_sorted=True)
if sparse_size is None:
sparse_size = torch.Size((index.max(dim=-1)[0].cpu() + 1).tolist())
@property
def _storage(self):
return self.__storage
self.__index__ = index
self.__value__ = value
self.__sparse_size__ = sparse_size
@_storage.setter
def _storage(self, storage):
self.__storage = storage
for name in methods:
setattr(self, name, getattr(storage, name))
if not is_sorted and not self.__is_sorted__():
self.__sort__()
def clone(self):
return SparseTensor.from_storage(self._storage.clone())
def __copy__(self):
return self.clone()
def __deepcopy__(self, memo):
memo = memo.setdefault('SparseStorage', {})
if self._cdata in memo:
return memo[self._cdata]
new_sparse_tensor = self.clone()
memo[self._cdata] = new_sparse_tensor
return new_sparse_tensor
def coo(self):
return self._index, self._value
def csr(self):
return self._col, self._rowptr, self._value
def csc(self):
perm = self._arg_csr_to_csc
return self._row[perm], self._colptr, self._value[perm]
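# Worked example for the three layouts (assuming entries are kept in
# row-major order), using the 3 x 4 matrix built in `__main__` below:
#   coo(): index = [[0, 0, 1, 1, 2, 2],
#                   [1, 2, 2, 3, 0, 1]],  value = [2, 1, 3, 4, 5, 6]
#   csr(): col = [1, 2, 2, 3, 0, 1],  rowptr = [0, 2, 4, 6],  value = [2, 1, 3, 4, 5, 6]
#   csc(): row = [2, 0, 2, 0, 1, 1],  colptr = [0, 1, 3, 5, 6],  value = [5, 2, 6, 1, 3, 4]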
def is_quadratic(self):
return self.sparse_size()[0] == self.sparse_size()[1]
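# is_symmetric compares the COO representation of the matrix with that of
# its transpose; for a symmetric matrix both index and value tensors
# coincide entry for entry.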
def is_symmetric(self):
if not self.is_quadratic():
return False
index1, value1 = self.coo()
index2, value2 = self.t().coo()
index_symmetric = (index1 == index2).all()
value_symmetric = (value1 == value2).all() if self.has_value else True
return index_symmetric and value_symmetric
def set_value(self, value, layout):
if value is not None and layout == 'csc':
value = value[self._arg_csc_to_csr]
return self._apply_value(value)
def set_value_(self, value, layout):
if value is not None and layout == 'csc':
value = value[self._arg_csc_to_csr]
return self._apply_value_(value)
def t(self):
storage = SparseStorage(
self._col[self._arg_csr_to_csc],
self._row[self._arg_csr_to_csc],
self._value[self._arg_csr_to_csc] if self.has_value else None,
self.sparse_size()[::-1],
self._colptr,
self._rowptr,
self._arg_csc_to_csr,
self._arg_csr_to_csc,
is_sorted=True,
)
return self.__class__.from_storage(storage)
def to(self, *args, **kwargs):
# TODO
pass
def matmul(self, mat2):
pass
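# Sketch only (not the implementation of this commit): a sparse-dense
# product over the COO entries could be written with index_add_, e.g.
#
#   def matmul(self, mat2):
#       row, col = self._row, self._col
#       value = self._value if self._value is not None \
#           else torch.ones_like(col, dtype=mat2.dtype)
#       out = mat2.new_zeros((self.size(0), mat2.size(1)))
#       out.index_add_(0, row, value.view(-1, 1) * mat2[col])
#       return out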
def size(self, dim=None):
size = self.__sparse_size__
size += () if self.__value__ is None else self.__value__.size()[1:]
return size if dim is None else size[dim]
def storage(self):
return self._storage
def coalesce(self, reduce='add'):
pass
@property
def shape(self):
return self.size()
def dim(self):
return len(self.size())
@property
def dtype(self):
return None if self.__value__ is None else self.__value__.dtype
@property
def device(self):
return self.__index__.device
def nnz(self):
return self.__index__.size(1)
def is_coalesced(self):
pass
def numel(self):
return self.nnz() if self.__value__ is None else self.__value__.numel()
def add(self, layout=None):
# sub, mul, div
# can take scalars, tensors and other sparse matrices
# inplace variants can only take scalars or tensors
pass
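# Sketch only (not part of this commit): with a scalar or per-nonzero tensor
# operand only `value` changes, e.g.
#
#   def add_(self, other):
#       value = self._value if self._value is not None \
#           else torch.ones(self.nnz(), device=self.device)
#       return self.set_value_(value + other, layout='coo')
#
# Adding another sparse matrix additionally requires merging both index
# sets, e.g. by concatenating the COO entries of the two operands and
# coalescing with reduce='add'.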
def clone(self):
return self.__class__(
index=self.__index__.clone(),
value=None if self.__value__ is None else self.__value__.clone(),
sparse_size=self.__sparse_size__,
is_sorted=True,
)
def to_dense(self, dtype=None):
dtype = dtype or self.dtype
mat = torch.zeros(self.size(), dtype=dtype, device=self.device)
mat[self._row, self._col] = 1 if self._value is None else self._value
return mat
def sparse_resize_(self, *sizes):
assert len(sizes) == 2
self.__sparse_size__ = torch.Size(sizes)
def to_scipy(self):
raise NotImplementedError
def __is_sorted__(self):
idx1 = self.size(1) * self.__index__[0] + self.__index__[1]
idx2 = torch.cat([idx1.new_zeros(1), idx1[:-1]], dim=0)
return (idx1 >= idx2).all().item()
def to_torch_sparse_coo_tensor(self):
raise NotImplementedError
def __sort__(self):
idx = self.__sparse_size__[1] * self.__index__[0] + self.__index__[1]
perm = idx.argsort()
self.__index__ = self.__index__[:, perm]
self.__value__ = None if self.__value__ is None else self.__value__[perm]
# TODO: Slicing, (sum|max|min|prod|...), standard operators, masking, perm
def __repr__(self):
i = ' ' * 6
infos = [f'index={indent(self.__index__.__repr__(), i)[len(i):]}']
if self.__value__ is not None:
infos += [f'value={indent(self.__value__.__repr__(), i)[len(i):]}']
infos += [f'size={tuple(self.size())}, nnz={self.nnz()}']
index, value = self.coo()
infos = [f'index={indent(index.__repr__(), i)[len(i):]}']
if value is not None:
infos += [f'value={indent(value.__repr__(), i)[len(i):]}']
infos += [
f'size={tuple(self.size())}, '
f'nnz={self.nnz()}, '
f'density={100 * self.density():.02f}%'
]
infos = ',\n'.join(infos)
i = ' ' * (len(self.__class__.__name__) + 1)
@@ -94,14 +149,38 @@ class SparseTensor(object):
if __name__ == '__main__':
index = torch.tensor([
[0, 0, 1, 1, 2, 2],
[2, 1, 2, 3, 0, 1],
])
value = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float)
from torch_geometric.datasets import Reddit, Planetoid # noqa
import time # noqa
mat1 = SparseTensor(index, value)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
# dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/PubMed', 'PubMed')
data = dataset[0].to(device)
_bytes = data.edge_index.numel() * 8
_kbytes = _bytes / 1024
_mbytes = _kbytes / 1024
_gbytes = _mbytes / 1024
print(f'Storage: {_gbytes:.04f} GB')
mat1 = SparseTensor(data.edge_index)
print(mat1)
mat1 = mat1.t()
mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
device=device)
mat2 = mat2.coalesce()
mat2 = mat2.t().coalesce()
mat2 = torch.sparse_coo_tensor(index, value)
# print(mat2)
index1, value1 = mat1.coo()
index2, value2 = mat2._indices(), mat2._values()
assert torch.allclose(index1, index2)
out1 = mat1.to_dense()
out2 = mat2.to_dense()
assert torch.allclose(out1, out2)
mat1 = SparseTensor.from_dense(out1)
print(mat1)
@@ -41,7 +41,7 @@ class SparseStorage(object):
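# CSR row pointer: rowptr[i] is the offset of the first entry of row i and
# rowptr[-1] equals nnz, so rowptr has sparse_size[0] + 1 elements (matching
# the assert on rowptr.numel() below).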
rowptr = torch.cat([row.new_zeros(1), out_deg.cumsum(0)], dim=0)
else:
assert rowptr.dtype == torch.long and rowptr.device == row.device
assert rowptr.dim() == 1 and rowptr.size(0) == sparse_size[0] - 1
assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0]
if colptr is None:
ones = torch.ones_like(col) if ones is None else ones
@@ -49,24 +49,24 @@ class SparseStorage(object):
colptr = torch.cat([col.new_zeros(1), in_deg.cumsum(0)], dim=0)
else:
assert colptr.dtype == torch.long and colptr.device == col.device
assert colptr.dim() == 1 and colptr.size(0) == sparse_size[1] - 1
assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1]
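# arg_csr_to_csc maps the row-major (CSR-ordered) entries to column-major
# (CSC) order: sorting the key sparse_size[0] * col + row groups entries by
# column first and by row within each column.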
if arg_csr_to_csc is None:
idx = sparse_size[0] * col + row
arg_csr_to_csc = idx.argsort()
else:
assert arg_csr_to_csc == torch.long
assert arg_csr_to_csc.dtype == torch.long
assert arg_csr_to_csc.device == row.device
assert arg_csr_to_csc.dim() == 1
assert arg_csr_to_csc.size(0) == row.size(0)
assert arg_csr_to_csc.numel() == row.numel()
if arg_csc_to_csr is None:
arg_csc_to_csr = arg_csr_to_csc.argsort()
else:
assert arg_csc_to_csr == torch.long
assert arg_csc_to_csr.dtype == torch.long
assert arg_csc_to_csr.device == row.device
assert arg_csc_to_csr.dim() == 1
assert arg_csc_to_csr.size(0) == row.size(0)
assert arg_csc_to_csr.numel() == row.numel()
self.__row = row
self.__col = col
@@ -85,6 +85,7 @@ class SparseStorage(object):
def _col(self):
return self.__col
@property
def _index(self):
return torch.stack([self.__row, self.__col], dim=0)
@@ -120,6 +121,9 @@ class SparseStorage(object):
size += () if self.__value is None else self.__value.size()[1:]
return size if dim is None else size[dim]
def dim(self):
return len(self.size())
@property
def shape(self):
return self.size()
@@ -129,20 +133,46 @@ class SparseStorage(object):
self.__sparse_size = sizes
return self
def nnz(self):
return self.__row.size(0)
def density(self):
return self.nnz() / (self.__sparse_size[0] * self.__sparse_size[1])
def sparsity(self):
return 1 - self.density()
def avg_row_length(self):
return self.nnz() / self.__sparse_size[0]
def avg_col_length(self):
return self.nnz() / self.__sparse_size[1]
def numel(self):
return self.nnz() if self.__value is None else self.__value.numel()
def clone(self):
return self.__apply(lambda x: x.clone())
return self._apply(lambda x: x.clone())
def __copy__(self):
return self.clone()
def __deepcopy__(self, memo):
memo = memo.setdefault('SparseStorage', {})
if self._cdata in memo:
return memo[self._cdata]
new_storage = self.clone()
memo[self._cdata] = new_storage
return new_storage
def pin_memory(self):
return self.__apply(lambda x: x.pin_memory())
return self._apply(lambda x: x.pin_memory())
def is_pinned(self):
return all([x.is_pinned() for x in self.__attributes])
def share_memory_(self):
return self.__apply_(lambda x: x.share_memory_())
return self._apply_(lambda x: x.share_memory_())
def is_shared(self):
return all([x.is_shared() for x in self.__attributes])
@@ -152,10 +182,10 @@
return self.__row.device
def cpu(self):
return self.__apply(lambda x: x.cpu())
return self._apply(lambda x: x.cpu())
def cuda(self, device=None, non_blocking=False, **kwargs):
return self.__apply(lambda x: x.cuda(device, non_blocking, **kwargs))
return self._apply(lambda x: x.cuda(device, non_blocking, **kwargs))
@property
def is_cuda(self):
@@ -167,11 +197,12 @@
def to(self, *args, **kwargs):
if 'device' in kwargs:
out = self.__apply(lambda x: x.to(kwargs['device']))
out = self._apply(lambda x: x.to(**kwargs))
del kwargs['device']
args = list(args)
for arg in args[:]:
if isinstance(arg, str) or isinstance(arg, torch.device):
out = self.__apply(lambda x: x.to(arg))
out = self._apply(lambda x: x.to(arg, **kwargs))
args.remove(arg)
if len(args) > 0 and len(kwargs) > 0:
@@ -180,72 +211,70 @@
return out
def type(self, dtype=None, non_blocking=False, **kwargs):
return self.dtype if dtype is None else self.__apply_value(
return self.dtype if dtype is None else self._apply_value(
lambda x: x.type(dtype, non_blocking, **kwargs))
def is_floating_point(self):
return self.__value is None or torch.is_floating_point(self.__value)
def bfloat16(self):
return self.__apply_value(lambda x: x.bfloat16())
return self._apply_value(lambda x: x.bfloat16())
def bool(self):
return self.__apply_value(lambda x: x.bool())
return self._apply_value(lambda x: x.bool())
def byte(self):
return self.__apply_value(lambda x: x.byte())
return self._apply_value(lambda x: x.byte())
def char(self):
return self.__apply_value(lambda x: x.char())
return self._apply_value(lambda x: x.char())
def half(self):
return self.__apply_value(lambda x: x.half())
return self._apply_value(lambda x: x.half())
def float(self):
return self.__apply_value(lambda x: x.float())
return self._apply_value(lambda x: x.float())
def double(self):
return self.__apply_value(lambda x: x.double())
return self._apply_value(lambda x: x.double())
def short(self):
return self.__apply_value(lambda x: x.short())
return self._apply_value(lambda x: x.short())
def int(self):
return self.__apply_value(lambda x: x.int())
return self._apply_value(lambda x: x.int())
def long(self):
return self.__apply_value(lambda x: x.long())
###########################################################################
def __keys(self):
return inspect.getfullargspec(self.__init__)[0][1:-1]
return self._apply_value(lambda x: x.long())
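# __state collects the constructor arguments of this storage by inspecting
# __init__ (all parameters except `self` and `is_sorted`) and reading the
# corresponding name-mangled attributes; the _apply helpers then rebuild a
# new SparseStorage from a transformed copy of that state (or write it back
# in place for the trailing-underscore variants).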
def __state(self):
return {
key: getattr(self, f'_{self.__class__.__name__}__{key}')
for key in self.__keys()
for key in inspect.getfullargspec(self.__init__)[0][1:-1]
}
def __apply_value(self, func):
def _apply_value(self, func):
if self.__value is None:
return self
state = self.__state()
state['value'] = func(self.__value)
return self.__class__(is_sorted=True, **state)
def __apply_value_(self, func):
def _apply_value_(self, func):
self.__value = None if self.__value is None else func(self.__value)
return self
def __apply(self, func):
state = {key: func(item) for key, item in self.__state().items()}
def _apply(self, func):
state = self.__state().items()
state = {k: func(v) if torch.is_tensor(v) else v for k, v in state}
return self.__class__(is_sorted=True, **state)
def __apply_(self, func):
state = self.__state()
del state['value']
for key, item in self.__state().items():
setattr(self, f'_{self.__class__.__name__}__{key}', func(item))
return self.__apply_value_(func)
def _apply_(self, func):
for k, v in self.__state().items():
v = func(v) if torch.is_tensor(v) else v
setattr(self, f'_{self.__class__.__name__}__{k}', v)
return self
if __name__ == '__main__':
......