Commit a1f64207 authored by rusty1s's avatar rusty1s
Browse files

storage functionality

parent b3746aab
import inspect
import torch
from torch import Size
from torch_scatter import scatter_add, segment_add
......@@ -37,18 +39,34 @@ class SparseStorage(object):
ones = torch.ones_like(row)
out_deg = segment_add(ones, row, dim=0, dim_size=sparse_size[0])
rowptr = torch.cat([row.new_zeros(1), out_deg.cumsum(0)], dim=0)
else:
assert rowptr.dtype == torch.long and rowptr.device == row.device
assert rowptr.dim() == 1 and rowptr.size(0) == sparse_size[0] - 1
if colptr is None:
ones = torch.ones_like(col) if ones is None else ones
in_deg = scatter_add(ones, col, dim=0, dim_size=sparse_size[1])
colptr = torch.cat([col.new_zeros(1), in_deg.cumsum(0)], dim=0)
else:
assert colptr.dtype == torch.long and colptr.device == col.device
assert colptr.dim() == 1 and colptr.size(0) == sparse_size[1] - 1
if arg_csr_to_csc is None:
idx = sparse_size[0] * col + row
arg_csr_to_csc = idx.argsort()
else:
assert arg_csr_to_csc == torch.long
assert arg_csr_to_csc.device == row.device
assert arg_csr_to_csc.dim() == 1
assert arg_csr_to_csc.size(0) == row.size(0)
if arg_csr_to_csc is None:
if arg_csc_to_csr is None:
arg_csc_to_csr = arg_csr_to_csc.argsort()
else:
assert arg_csc_to_csr == torch.long
assert arg_csc_to_csr.device == row.device
assert arg_csc_to_csr.dim() == 1
assert arg_csc_to_csr.size(0) == row.size(0)
self.__row = row
self.__col = col
......@@ -60,34 +78,34 @@ class SparseStorage(object):
self.__arg_csc_to_csr = arg_csc_to_csr
@property
def row(self):
def _row(self):
return self.__row
@property
def col(self):
def _col(self):
return self.__col
def index(self):
def _index(self):
return torch.stack([self.__row, self.__col], dim=0)
@property
def rowptr(self):
def _rowptr(self):
return self.__rowptr
@property
def colptr(self):
def _colptr(self):
return self.__colptr
@property
def arg_csr_to_csc(self):
def _arg_csr_to_csc(self):
return self.__arg_csr_to_csc
@property
def arg_csc_to_csr(self):
def _arg_csc_to_csr(self):
return self.__arg_csc_to_csr
@property
def value(self):
def _value(self):
return self.__value
@property
......@@ -99,7 +117,7 @@ class SparseStorage(object):
def size(self, dim=None):
size = self.__sparse_size
size += () if self.has_value is None else self.__value.size()[1:]
size += () if self.__value is None else self.__value.size()[1:]
return size if dim is None else size[dim]
@property
......@@ -109,102 +127,125 @@ class SparseStorage(object):
def sparse_resize_(self, *sizes):
assert len(sizes) == 2
self.__sparse_size == sizes
return self
def clone(self):
raise NotImplementedError
return self.__apply(lambda x: x.clone())
def copy_(self):
raise NotImplementedError
def __copy__(self):
return self.clone()
def pin_memory(self):
raise NotImplementedError
return self.__apply(lambda x: x.pin_memory())
def is_pinned(self):
raise NotImplementedError
return all([x.is_pinned for x in self.__attributes])
def share_memory_(self):
raise NotImplementedError
return self.__apply_(lambda x: x.share_memory_())
def is_shared(self):
raise NotImplementedError
return all([x.is_shared for x in self.__attributes])
@property
def device(self):
return self.__row.device
def cpu(self):
pass
return self.__apply(lambda x: x.cpu())
def cuda(device=None, non_blocking=False, **kwargs):
pass
def cuda(self, device=None, non_blocking=False, **kwargs):
return self.__apply(lambda x: x.cuda(device, non_blocking, **kwargs))
@property
def is_cuda(self):
pass
return self.__row.is_cuda
@property
def dtype(self):
pass
return None if self.__value is None else self.__value.dtype
def to(self, *args, **kwargs):
if 'device' in kwargs:
out = self.__apply(lambda x: x.to(kwargs['device']))
del kwargs['device']
for arg in args[:]:
if isinstance(arg, str) or isinstance(arg, torch.device):
out = self.__apply(lambda x: x.to(arg))
args.remove(arg)
def type(dtype=None, non_blocking=False, **kwargs):
pass
if len(args) > 0 and len(kwargs) > 0:
out = self.type(*args, **kwargs)
return out
def type(self, dtype=None, non_blocking=False, **kwargs):
return self.dtype if dtype is None else self.__apply_value(
lambda x: x.type(dtype, non_blocking, **kwargs))
def is_floating_point(self):
pass
return self.__value is None or torch.is_floating_point(self.__value)
def bfloat16(self):
pass
return self.__apply_value(lambda x: x.bfloat16())
def bool(self):
pass
return self.__apply_value(lambda x: x.bool())
def byte(self):
pass
return self.__apply_value(lambda x: x.byte())
def char(self):
pass
return self.__apply_value(lambda x: x.char())
def half(self):
pass
return self.__apply_value(lambda x: x.half())
def float(self):
pass
return self.__apply_value(lambda x: x.float())
def double(self):
pass
return self.__apply_value(lambda x: x.double())
def short(self):
pass
return self.__apply_value(lambda x: x.short())
def int(self):
pass
return self.__apply_value(lambda x: x.int())
def long(self):
pass
return self.__apply_value(lambda x: x.long())
###########################################################################
def __apply_index(self, func):
pass
def __keys(self):
return inspect.getfullargspec(self.__init__)[0][1:-1]
def __apply_index_(self, func):
self.__row = func(self.__row)
self.__col = func(self.__col)
self.__rowptr = func(self.__rowptr)
self.__colptr = func(self.__colptr)
self.__arg_csr_to_csc = func(self.__arg_csr_to_csc)
self.__arg_csc_to_csr = func(self.__arg_csc_to_csr)
def __state(self):
return {
key: getattr(self, f'_{self.__class__.__name__}__{key}')
for key in self.__keys()
}
def __apply_value(self, func):
pass
state = self.__state()
state['value'] == func(self.__value)
return self.__class__(is_sorted=True, **state)
def __apply_value_(self, func):
self.__value = func(self.__value) if self.has_value else None
self.__value = None if self.__value is None else func(self.__value)
return self
def __apply(self, func):
pass
state = {key: func(item) for key, item in self.__state().items()}
return self.__class__(is_sorted=True, **state)
def __apply_(self, func):
self.__apply_index_(func)
self.__apply_value_(func)
state = self.__state()
del state['value']
for key, item in self.__state().items():
setattr(self, f'_{self.__class__.__name__}__{key}', func(item))
return self.__apply_value_(func)
if __name__ == '__main__':
......@@ -219,14 +260,3 @@ if __name__ == '__main__':
row, col = edge_index
storage = SparseStorage(row, col)
# idx = data.num_nodes * col + row
# perm = idx.argsort()
# row, col = row[perm], col[perm]
# print(row[:20])
# print(col[:20])
# print('--------')
# perm = perm.argsort()
# row, col = row[perm], col[perm]
# print(row[:20])
# print(col[:20])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment