"vscode:/vscode.git/clone" did not exist on "368958df6f79da805bed6178b90bf1ca76e5d57b"
Commit a25d8256 authored by rusty1s's avatar rusty1s
Browse files

more functionality

parent 36d045fd
import warnings
import inspect
from textwrap import indent
import torch
from torch_sparse.storage import SparseStorage
......@@ -8,11 +10,15 @@ methods = list(zip(*inspect.getmembers(SparseStorage)))[0]
methods = [name for name in methods if '__' not in name and name != 'clone']
def __is_scalar__(x):
return isinstance(x, int) or isinstance(x, float)
class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
assert index.dim() == 2 and index.size(0) == 2
self._storage = SparseStorage(index[0], index[1], value, sparse_size,
is_sorted=is_sorted)
self._storage = SparseStorage(
index[0], index[1], value, sparse_size, is_sorted=is_sorted)
@classmethod
def from_storage(self, storage):
......@@ -78,16 +84,29 @@ class SparseTensor(object):
value_symmetric = (value1 == value2).all() if self.has_value else True
return index_symmetric and value_symmetric
def set_value(self, value, layout):
def set_value(self, value, layout=None):
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
if value is not None and layout == 'csc':
value = value[self._arg_csc_to_csr]
return self._apply_value(value)
def set_value_(self, value, layout):
def set_value_(self, value, layout=None):
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
if value is not None and layout == 'csc':
value = value[self._arg_csc_to_csr]
return self._apply_value_(value)
def set_diag(self, value):
raise NotImplementedError
def t(self):
storage = SparseStorage(
self._col[self._arg_csr_to_csc],
......@@ -102,22 +121,119 @@ class SparseTensor(object):
)
return self.__class__.from_storage(storage)
def matmul(self, mat2):
raise NotImplementedError
def coalesce(self, reduce='add'):
raise NotImplementedError
def is_coalesced(self):
raise NotImplementedError
def add(self, layout=None):
# sub, mul, div
# can take scalars, tensors and other sparse matrices
# inplace variants can only take scalars or tensors
def masked_select(self, mask):
raise NotImplementedError
def index_select(self, index):
raise NotImplementedError
# TODO: Slicing, (sum|max|min|prod|...), standard operators, masing, perm
def select(self, dim, index):
raise NotImplementedError
def filter(self, index):
assert self.is_symmetric
assert index.dtype == torch.long or index.dtype == torch.bool
raise NotImplementedError
def permute(self, index):
assert index.dtype == torch.long
return self.filter(index)
def __getitem__(self, idx):
# Convert int and slice to index tensor
# Filter list into edge and sparse slice
raise NotImplementedError
def __reduce(self, dim, reduce, only_nnz):
raise NotImplementedError
def sum(self, dim):
return self.__reduce(dim, reduce='add', only_nnz=True)
def prod(self, dim):
return self.__reduce(dim, reduce='mul', only_nnz=True)
def min(self, dim, only_nnz=False):
return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
def max(self, dim, only_nnz=False):
return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
def mean(self, dim, only_nnz=False):
return self.__reduce(dim, reduce='mean', only_nnz=only_nnz)
def matmul(self, mat, reduce='add'):
assert self.numel() == self.nnz() # Disallow multi-dimensional value
if torch.is_tensor(mat):
raise NotImplementedError
elif isinstance(mat, self.__class__):
assert reduce == 'add'
assert mat.numel() == mat.nnz() # Disallow multi-dimensional value
raise NotImplementedError
raise ValueError('Argument needs to be of type `torch.tensor` or '
'type `torch_sparse.SparseTensor`.')
def add(self, other, layout=None):
if __is_scalar__(other):
if self.has_value:
return self.set_value(self._value + other, 'coo')
else:
return self.set_value(
torch.full((self.nnz(), ), other + 1), 'coo')
elif torch.is_tensor(other):
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
if layout == 'csc':
other = other[self._arg_csc_to_csr]
if self.has_value:
return self.set_value(self._value + other, 'coo')
else:
return self.set_value(other + 1, 'coo')
elif isinstance(other, self.__class__):
raise NotImplementedError
raise ValueError('Argument needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.')
def add_(self, other, layout=None):
if isinstance(other, int) or isinstance(other, float):
raise NotImplementedError
elif torch.is_tensor(other):
raise NotImplementedError
raise ValueError('Argument needs to be a scalar or of type '
'`torch.tensor`.')
def __add__(self, other):
return self.add(other)
def __radd__(self, other):
return self.add(other)
def sub(self, layout=None):
raise NotImplementedError
def sub_(self, layout=None):
raise NotImplementedError
def mul(self, layout=None):
raise NotImplementedError
def mul_(self, layout=None):
raise NotImplementedError
def div(self, layout=None):
raise NotImplementedError
def div_(self, layout=None):
raise NotImplementedError
def to_dense(self, dtype=None):
dtype = dtype or self.dtype
......@@ -125,11 +241,17 @@ class SparseTensor(object):
mat[self._row, self._col] = self._value if self.has_value else 1
return mat
def to_scipy(self):
def to_scipy(self, layout):
raise NotImplementedError
def to_torch_sparse_coo_tensor(self):
raise NotImplementedError
def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
index, value = self.coo()
return torch.sparse_coo_tensor(
index,
torch.ones_like(self._row, dtype) if value is None else value,
self.size(),
device=self.device,
requires_grad=requires_grad)
def __repr__(self):
i = ' ' * 6
......@@ -156,7 +278,8 @@ if __name__ == '__main__':
device = 'cpu'
# dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/PubMed', 'PubMed')
dataset = Planetoid('/tmp/Cora', 'Cora')
# dataset = Planetoid('/tmp/PubMed', 'PubMed')
data = dataset[0].to(device)
_bytes = data.edge_index.numel() * 8
......@@ -169,8 +292,8 @@ if __name__ == '__main__':
print(mat1)
mat1 = mat1.t()
mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
device=device)
mat2 = torch.sparse_coo_tensor(
data.edge_index, torch.ones(data.num_edges), device=device)
mat2 = mat2.coalesce()
mat2 = mat2.t().coalesce()
......@@ -182,5 +305,12 @@ if __name__ == '__main__':
out2 = mat2.to_dense()
assert torch.allclose(out1, out2)
mat1 = SparseTensor.from_dense(out1)
print(mat1)
out = 2 + mat1
print(out)
# mat1[1]
# mat1[1, 1]
# mat1[..., -1]
# mat1[:, -1]
# mat1[1:4, 1:4]
# mat1[torch.tensor([0, 1, 2])]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment