Commit a1f64207 authored by rusty1s

storage functionality

parent b3746aab
import inspect
import torch
from torch import Size
from torch_scatter import scatter_add, segment_add
@@ -37,18 +39,34 @@ class SparseStorage(object):
            ones = torch.ones_like(row)
            out_deg = segment_add(ones, row, dim=0, dim_size=sparse_size[0])
            rowptr = torch.cat([row.new_zeros(1), out_deg.cumsum(0)], dim=0)
        else:
            assert rowptr.dtype == torch.long and rowptr.device == row.device
            assert rowptr.dim() == 1 and rowptr.size(0) == sparse_size[0] + 1
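        # For a row-sorted COO layout, `rowptr` is the standard CSR offset
        # array: rowptr[i] points at the first entry of row i and
        # rowptr[-1] == nnz, hence the `sparse_size[0] + 1` length above.
        # A tiny standalone sketch with made-up data (not part of this class):
        #   row = torch.tensor([0, 0, 1, 2, 2, 2])
        #   deg = torch.bincount(row, minlength=3)   # tensor([2, 1, 3])
        #   rowptr = torch.cat([row.new_zeros(1), deg.cumsum(0)], dim=0)
        #   # -> tensor([0, 2, 3, 6])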
        if colptr is None:
            ones = torch.ones_like(col) if ones is None else ones
            in_deg = scatter_add(ones, col, dim=0, dim_size=sparse_size[1])
            colptr = torch.cat([col.new_zeros(1), in_deg.cumsum(0)], dim=0)
        else:
            assert colptr.dtype == torch.long and colptr.device == col.device
            assert colptr.dim() == 1 and colptr.size(0) == sparse_size[1] + 1
        if arg_csr_to_csc is None:
            idx = sparse_size[0] * col + row
            arg_csr_to_csc = idx.argsort()
        else:
            assert arg_csr_to_csc.dtype == torch.long
            assert arg_csr_to_csc.device == row.device
            assert arg_csr_to_csc.dim() == 1
            assert arg_csr_to_csc.size(0) == row.size(0)
        if arg_csc_to_csr is None:
            arg_csc_to_csr = arg_csr_to_csc.argsort()
        else:
            assert arg_csc_to_csr.dtype == torch.long
            assert arg_csc_to_csr.device == row.device
            assert arg_csc_to_csr.dim() == 1
            assert arg_csc_to_csr.size(0) == row.size(0)
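        # `idx = sparse_size[0] * col + row` ranks entries in column-major
        # order, so `idx.argsort()` is the CSR->CSC permutation, and sorting
        # that permutation again inverts it. A made-up example (not part of
        # this class, assuming 3 rows):
        #   row, col = torch.tensor([0, 0, 1]), torch.tensor([0, 2, 1])
        #   perm = (3 * col + row).argsort()   # tensor([0, 2, 1])
        #   inv = perm.argsort()               # tensor([0, 2, 1])
        #   assert torch.equal(perm[inv], torch.arange(3))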
        self.__row = row
        self.__col = col
@@ -60,34 +78,34 @@ class SparseStorage(object):
        self.__arg_csc_to_csr = arg_csc_to_csr
    @property
    def _row(self):
        return self.__row

    @property
    def _col(self):
        return self.__col

    def _index(self):
        return torch.stack([self.__row, self.__col], dim=0)

    @property
    def _rowptr(self):
        return self.__rowptr

    @property
    def _colptr(self):
        return self.__colptr

    @property
    def _arg_csr_to_csc(self):
        return self.__arg_csr_to_csc

    @property
    def _arg_csc_to_csr(self):
        return self.__arg_csc_to_csr

    @property
    def _value(self):
        return self.__value
    @property
@@ -99,7 +117,7 @@ class SparseStorage(object):
    def size(self, dim=None):
        size = self.__sparse_size
        size += () if self.__value is None else self.__value.size()[1:]
        return size if dim is None else size[dim]
    @property
@@ -109,102 +127,125 @@ class SparseStorage(object):
    def sparse_resize_(self, *sizes):
        assert len(sizes) == 2
        self.__sparse_size = sizes
        return self
    def clone(self):
        return self.__apply(lambda x: x.clone())

    def __copy__(self):
        return self.clone()

    def pin_memory(self):
        return self.__apply(lambda x: x.pin_memory())

    def is_pinned(self):
        # `is_pinned` is a method, so it must be called, not referenced.
        return all(x.is_pinned() for x in self.__state().values()
                   if torch.is_tensor(x))

    def share_memory_(self):
        return self.__apply_(lambda x: x.share_memory_())

    def is_shared(self):
        return all(x.is_shared() for x in self.__state().values()
                   if torch.is_tensor(x))
    @property
    def device(self):
        return self.__row.device
    def cpu(self):
        return self.__apply(lambda x: x.cpu())

    def cuda(self, device=None, non_blocking=False, **kwargs):
        return self.__apply(lambda x: x.cuda(device, non_blocking, **kwargs))

    @property
    def is_cuda(self):
        return self.__row.is_cuda

    @property
    def dtype(self):
        return None if self.__value is None else self.__value.dtype
    def to(self, *args, **kwargs):
        out = self
        if 'device' in kwargs:
            out = out.__apply(lambda x: x.to(kwargs['device']))
            del kwargs['device']
        args = list(args)  # Tuples do not support `remove`.
        for arg in args[:]:
            if isinstance(arg, (str, torch.device)):
                out = out.__apply(lambda x: x.to(arg))
                args.remove(arg)
        if len(args) > 0 or len(kwargs) > 0:
            out = out.type(*args, **kwargs)
        return out

    def type(self, dtype=None, non_blocking=False, **kwargs):
        return self.dtype if dtype is None else self.__apply_value(
            lambda x: x.type(dtype, non_blocking, **kwargs))
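    # Usage sketch (hypothetical calls, assuming an existing `storage`):
    #   storage.to('cuda:0')        # device-like args are routed to __apply
    #   storage.to(torch.double)    # remaining args fall through to type()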
    def is_floating_point(self):
        return self.__value is None or torch.is_floating_point(self.__value)

    def bfloat16(self):
        return self.__apply_value(lambda x: x.bfloat16())

    def bool(self):
        return self.__apply_value(lambda x: x.bool())

    def byte(self):
        return self.__apply_value(lambda x: x.byte())

    def char(self):
        return self.__apply_value(lambda x: x.char())

    def half(self):
        return self.__apply_value(lambda x: x.half())

    def float(self):
        return self.__apply_value(lambda x: x.float())

    def double(self):
        return self.__apply_value(lambda x: x.double())

    def short(self):
        return self.__apply_value(lambda x: x.short())

    def int(self):
        return self.__apply_value(lambda x: x.int())

    def long(self):
        return self.__apply_value(lambda x: x.long())
###########################################################################
    def __keys(self):
        return inspect.getfullargspec(self.__init__)[0][1:-1]
    def __state(self):
        return {
            key: getattr(self, f'_{self.__class__.__name__}__{key}')
            for key in self.__keys()
        }
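    # `__keys` recovers the constructor's parameter names by introspection:
    # `getfullargspec(self.__init__)[0]` lists the positional parameters, and
    # `[1:-1]` drops `self` plus the trailing flag (presumably `is_sorted`,
    # which `__apply` passes explicitly below). `__state` then reads the
    # matching name-mangled private attributes, so new constructor arguments
    # are picked up by the copy/apply helpers automatically.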
    def __apply_value(self, func):
        state = self.__state()
        state['value'] = func(self.__value)
        return self.__class__(is_sorted=True, **state)
    def __apply_value_(self, func):
        self.__value = None if self.__value is None else func(self.__value)
        return self
    def __apply(self, func):
        # Guard non-tensor state (e.g. a `None` value or the sparse size
        # tuple) so `func` is only applied to tensors.
        state = {
            key: func(item) if torch.is_tensor(item) else item
            for key, item in self.__state().items()
        }
        return self.__class__(is_sorted=True, **state)
    def __apply_(self, func):
        state = self.__state()
        del state['value']
        for key, item in state.items():  # Iterate the copy without 'value'.
            if torch.is_tensor(item):
                setattr(self, f'_{self.__class__.__name__}__{key}', func(item))
        return self.__apply_value_(func)
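    # Design note: `__apply` is out-of-place (it builds a fresh SparseStorage
    # from the transformed state), while `__apply_` mutates the index tensors
    # in place and delegates the value tensor to `__apply_value_`. The
    # `*_value` variants touch only `value`, leaving the sparse layout intact.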
if __name__ == '__main__':
@@ -219,14 +260,3 @@ if __name__ == '__main__':
    row, col = edge_index
    storage = SparseStorage(row, col)
# idx = data.num_nodes * col + row
# perm = idx.argsort()
# row, col = row[perm], col[perm]
# print(row[:20])
# print(col[:20])
# print('--------')
# perm = perm.argsort()
# row, col = row[perm], col[perm]
# print(row[:20])
# print(col[:20])