"vscode:/vscode.git/clone" did not exist on "9a7ae77a4eda5b4f819fd22ce9b713fb79993201"
Commit a25d8256 authored by rusty1s's avatar rusty1s
Browse files

more functionality

parent 36d045fd
import warnings
import inspect import inspect
from textwrap import indent from textwrap import indent
import torch import torch
from torch_sparse.storage import SparseStorage from torch_sparse.storage import SparseStorage
...@@ -8,11 +10,15 @@ methods = list(zip(*inspect.getmembers(SparseStorage)))[0] ...@@ -8,11 +10,15 @@ methods = list(zip(*inspect.getmembers(SparseStorage)))[0]
methods = [name for name in methods if '__' not in name and name != 'clone'] methods = [name for name in methods if '__' not in name and name != 'clone']
def __is_scalar__(x):
return isinstance(x, int) or isinstance(x, float)
class SparseTensor(object): class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False): def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
assert index.dim() == 2 and index.size(0) == 2 assert index.dim() == 2 and index.size(0) == 2
self._storage = SparseStorage(index[0], index[1], value, sparse_size, self._storage = SparseStorage(
is_sorted=is_sorted) index[0], index[1], value, sparse_size, is_sorted=is_sorted)
@classmethod @classmethod
def from_storage(self, storage): def from_storage(self, storage):
...@@ -78,16 +84,29 @@ class SparseTensor(object): ...@@ -78,16 +84,29 @@ class SparseTensor(object):
value_symmetric = (value1 == value2).all() if self.has_value else True value_symmetric = (value1 == value2).all() if self.has_value else True
return index_symmetric and value_symmetric return index_symmetric and value_symmetric
def set_value(self, value, layout): def set_value(self, value, layout=None):
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
if value is not None and layout == 'csc': if value is not None and layout == 'csc':
value = value[self._arg_csc_to_csr] value = value[self._arg_csc_to_csr]
return self._apply_value(value) return self._apply_value(value)
def set_value_(self, value, layout): def set_value_(self, value, layout=None):
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
if value is not None and layout == 'csc': if value is not None and layout == 'csc':
value = value[self._arg_csc_to_csr] value = value[self._arg_csc_to_csr]
return self._apply_value_(value) return self._apply_value_(value)
def set_diag(self, value):
raise NotImplementedError
def t(self): def t(self):
storage = SparseStorage( storage = SparseStorage(
self._col[self._arg_csr_to_csc], self._col[self._arg_csr_to_csc],
...@@ -102,22 +121,119 @@ class SparseTensor(object): ...@@ -102,22 +121,119 @@ class SparseTensor(object):
) )
return self.__class__.from_storage(storage) return self.__class__.from_storage(storage)
def matmul(self, mat2):
raise NotImplementedError
def coalesce(self, reduce='add'): def coalesce(self, reduce='add'):
raise NotImplementedError raise NotImplementedError
def is_coalesced(self): def is_coalesced(self):
raise NotImplementedError raise NotImplementedError
def add(self, layout=None): def masked_select(self, mask):
# sub, mul, div raise NotImplementedError
# can take scalars, tensors and other sparse matrices
# inplace variants can only take scalars or tensors def index_select(self, index):
raise NotImplementedError
def select(self, dim, index):
raise NotImplementedError
def filter(self, index):
assert self.is_symmetric
assert index.dtype == torch.long or index.dtype == torch.bool
raise NotImplementedError
def permute(self, index):
assert index.dtype == torch.long
return self.filter(index)
def __getitem__(self, idx):
# Convert int and slice to index tensor
# Filter list into edge and sparse slice
raise NotImplementedError
def __reduce(self, dim, reduce, only_nnz):
raise NotImplementedError
def sum(self, dim):
return self.__reduce(dim, reduce='add', only_nnz=True)
def prod(self, dim):
return self.__reduce(dim, reduce='mul', only_nnz=True)
def min(self, dim, only_nnz=False):
return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
def max(self, dim, only_nnz=False):
return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
def mean(self, dim, only_nnz=False):
return self.__reduce(dim, reduce='mean', only_nnz=only_nnz)
def matmul(self, mat, reduce='add'):
assert self.numel() == self.nnz() # Disallow multi-dimensional value
if torch.is_tensor(mat):
raise NotImplementedError
elif isinstance(mat, self.__class__):
assert reduce == 'add'
assert mat.numel() == mat.nnz() # Disallow multi-dimensional value
raise NotImplementedError
raise ValueError('Argument needs to be of type `torch.tensor` or '
'type `torch_sparse.SparseTensor`.')
def add(self, other, layout=None):
if __is_scalar__(other):
if self.has_value:
return self.set_value(self._value + other, 'coo')
else:
return self.set_value(
torch.full((self.nnz(), ), other + 1), 'coo')
elif torch.is_tensor(other):
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
if layout == 'csc':
other = other[self._arg_csc_to_csr]
if self.has_value:
return self.set_value(self._value + other, 'coo')
else:
return self.set_value(other + 1, 'coo')
elif isinstance(other, self.__class__):
raise NotImplementedError raise NotImplementedError
raise ValueError('Argument needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.')
# TODO: Slicing, (sum|max|min|prod|...), standard operators, masing, perm def add_(self, other, layout=None):
if isinstance(other, int) or isinstance(other, float):
raise NotImplementedError
elif torch.is_tensor(other):
raise NotImplementedError
raise ValueError('Argument needs to be a scalar or of type '
'`torch.tensor`.')
def __add__(self, other):
return self.add(other)
def __radd__(self, other):
return self.add(other)
def sub(self, layout=None):
raise NotImplementedError
def sub_(self, layout=None):
raise NotImplementedError
def mul(self, layout=None):
raise NotImplementedError
def mul_(self, layout=None):
raise NotImplementedError
def div(self, layout=None):
raise NotImplementedError
def div_(self, layout=None):
raise NotImplementedError
def to_dense(self, dtype=None): def to_dense(self, dtype=None):
dtype = dtype or self.dtype dtype = dtype or self.dtype
...@@ -125,11 +241,17 @@ class SparseTensor(object): ...@@ -125,11 +241,17 @@ class SparseTensor(object):
mat[self._row, self._col] = self._value if self.has_value else 1 mat[self._row, self._col] = self._value if self.has_value else 1
return mat return mat
def to_scipy(self): def to_scipy(self, layout):
raise NotImplementedError raise NotImplementedError
def to_torch_sparse_coo_tensor(self): def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
raise NotImplementedError index, value = self.coo()
return torch.sparse_coo_tensor(
index,
torch.ones_like(self._row, dtype) if value is None else value,
self.size(),
device=self.device,
requires_grad=requires_grad)
def __repr__(self): def __repr__(self):
i = ' ' * 6 i = ' ' * 6
...@@ -156,7 +278,8 @@ if __name__ == '__main__': ...@@ -156,7 +278,8 @@ if __name__ == '__main__':
device = 'cpu' device = 'cpu'
# dataset = Reddit('/tmp/Reddit') # dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/PubMed', 'PubMed') dataset = Planetoid('/tmp/Cora', 'Cora')
# dataset = Planetoid('/tmp/PubMed', 'PubMed')
data = dataset[0].to(device) data = dataset[0].to(device)
_bytes = data.edge_index.numel() * 8 _bytes = data.edge_index.numel() * 8
...@@ -169,8 +292,8 @@ if __name__ == '__main__': ...@@ -169,8 +292,8 @@ if __name__ == '__main__':
print(mat1) print(mat1)
mat1 = mat1.t() mat1 = mat1.t()
mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges), mat2 = torch.sparse_coo_tensor(
device=device) data.edge_index, torch.ones(data.num_edges), device=device)
mat2 = mat2.coalesce() mat2 = mat2.coalesce()
mat2 = mat2.t().coalesce() mat2 = mat2.t().coalesce()
...@@ -182,5 +305,12 @@ if __name__ == '__main__': ...@@ -182,5 +305,12 @@ if __name__ == '__main__':
out2 = mat2.to_dense() out2 = mat2.to_dense()
assert torch.allclose(out1, out2) assert torch.allclose(out1, out2)
mat1 = SparseTensor.from_dense(out1) out = 2 + mat1
print(mat1) print(out)
# mat1[1]
# mat1[1, 1]
# mat1[..., -1]
# mat1[:, -1]
# mat1[1:4, 1:4]
# mat1[torch.tensor([0, 1, 2])]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment