Commit 8b07240b authored by rusty1s's avatar rusty1s
Browse files

Merge branch 'adj' of github.com:rusty1s/pytorch_sparse into adj

parents 519306d3 a25d8256
import warnings
import inspect import inspect
from textwrap import indent from textwrap import indent
import torch import torch
from torch_sparse.storage import SparseStorage from torch_sparse.storage import SparseStorage
...@@ -8,11 +10,15 @@ methods = list(zip(*inspect.getmembers(SparseStorage)))[0] ...@@ -8,11 +10,15 @@ methods = list(zip(*inspect.getmembers(SparseStorage)))[0]
methods = [name for name in methods if '__' not in name and name != 'clone'] methods = [name for name in methods if '__' not in name and name != 'clone']
def __is_scalar__(x):
return isinstance(x, int) or isinstance(x, float)
class SparseTensor(object): class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False): def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
assert index.dim() == 2 and index.size(0) == 2 assert index.dim() == 2 and index.size(0) == 2
self._storage = SparseStorage(index[0], index[1], value, sparse_size, self._storage = SparseStorage(
is_sorted=is_sorted) index[0], index[1], value, sparse_size, is_sorted=is_sorted)
@classmethod @classmethod
def from_storage(self, storage): def from_storage(self, storage):
...@@ -78,16 +84,29 @@ class SparseTensor(object): ...@@ -78,16 +84,29 @@ class SparseTensor(object):
value_symmetric = (value1 == value2).all() if self.has_value else True value_symmetric = (value1 == value2).all() if self.has_value else True
return index_symmetric and value_symmetric return index_symmetric and value_symmetric
def set_value(self, value, layout): def set_value(self, value, layout=None):
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
if value is not None and layout == 'csc': if value is not None and layout == 'csc':
value = value[self._arg_csc_to_csr] value = value[self._arg_csc_to_csr]
return self._apply_value(value) return self._apply_value(value)
def set_value_(self, value, layout): def set_value_(self, value, layout=None):
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
if value is not None and layout == 'csc': if value is not None and layout == 'csc':
value = value[self._arg_csc_to_csr] value = value[self._arg_csc_to_csr]
return self._apply_value_(value) return self._apply_value_(value)
def set_diag(self, value):
raise NotImplementedError
def t(self): def t(self):
storage = SparseStorage( storage = SparseStorage(
self._col[self._arg_csr_to_csc], self._col[self._arg_csr_to_csc],
...@@ -102,22 +121,119 @@ class SparseTensor(object): ...@@ -102,22 +121,119 @@ class SparseTensor(object):
) )
return self.__class__.from_storage(storage) return self.__class__.from_storage(storage)
def matmul(self, mat2):
raise NotImplementedError
def coalesce(self, reduce='add'): def coalesce(self, reduce='add'):
raise NotImplementedError raise NotImplementedError
def is_coalesced(self): def is_coalesced(self):
raise NotImplementedError raise NotImplementedError
def add(self, layout=None): def masked_select(self, mask):
# sub, mul, div raise NotImplementedError
# can take scalars, tensors and other sparse matrices
# inplace variants can only take scalars or tensors def index_select(self, index):
raise NotImplementedError raise NotImplementedError
# TODO: Slicing, (sum|max|min|prod|...), standard operators, masing, perm def select(self, dim, index):
raise NotImplementedError
def filter(self, index):
assert self.is_symmetric
assert index.dtype == torch.long or index.dtype == torch.bool
raise NotImplementedError
def permute(self, index):
assert index.dtype == torch.long
return self.filter(index)
def __getitem__(self, idx):
# Convert int and slice to index tensor
# Filter list into edge and sparse slice
raise NotImplementedError
def __reduce(self, dim, reduce, only_nnz):
raise NotImplementedError
def sum(self, dim):
return self.__reduce(dim, reduce='add', only_nnz=True)
def prod(self, dim):
return self.__reduce(dim, reduce='mul', only_nnz=True)
def min(self, dim, only_nnz=False):
return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
def max(self, dim, only_nnz=False):
return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
def mean(self, dim, only_nnz=False):
return self.__reduce(dim, reduce='mean', only_nnz=only_nnz)
def matmul(self, mat, reduce='add'):
assert self.numel() == self.nnz() # Disallow multi-dimensional value
if torch.is_tensor(mat):
raise NotImplementedError
elif isinstance(mat, self.__class__):
assert reduce == 'add'
assert mat.numel() == mat.nnz() # Disallow multi-dimensional value
raise NotImplementedError
raise ValueError('Argument needs to be of type `torch.tensor` or '
'type `torch_sparse.SparseTensor`.')
def add(self, other, layout=None):
if __is_scalar__(other):
if self.has_value:
return self.set_value(self._value + other, 'coo')
else:
return self.set_value(
torch.full((self.nnz(), ), other + 1), 'coo')
elif torch.is_tensor(other):
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
if layout == 'csc':
other = other[self._arg_csc_to_csr]
if self.has_value:
return self.set_value(self._value + other, 'coo')
else:
return self.set_value(other + 1, 'coo')
elif isinstance(other, self.__class__):
raise NotImplementedError
raise ValueError('Argument needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.')
def add_(self, other, layout=None):
if isinstance(other, int) or isinstance(other, float):
raise NotImplementedError
elif torch.is_tensor(other):
raise NotImplementedError
raise ValueError('Argument needs to be a scalar or of type '
'`torch.tensor`.')
def __add__(self, other):
return self.add(other)
def __radd__(self, other):
return self.add(other)
def sub(self, layout=None):
raise NotImplementedError
def sub_(self, layout=None):
raise NotImplementedError
def mul(self, layout=None):
raise NotImplementedError
def mul_(self, layout=None):
raise NotImplementedError
def div(self, layout=None):
raise NotImplementedError
def div_(self, layout=None):
raise NotImplementedError
def to_dense(self, dtype=None): def to_dense(self, dtype=None):
dtype = dtype or self.dtype dtype = dtype or self.dtype
...@@ -125,11 +241,17 @@ class SparseTensor(object): ...@@ -125,11 +241,17 @@ class SparseTensor(object):
mat[self._row, self._col] = self._value if self.has_value else 1 mat[self._row, self._col] = self._value if self.has_value else 1
return mat return mat
def to_scipy(self): def to_scipy(self, layout):
raise NotImplementedError raise NotImplementedError
def to_torch_sparse_coo_tensor(self): def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
raise NotImplementedError index, value = self.coo()
return torch.sparse_coo_tensor(
index,
torch.ones_like(self._row, dtype) if value is None else value,
self.size(),
device=self.device,
requires_grad=requires_grad)
def __repr__(self): def __repr__(self):
i = ' ' * 6 i = ' ' * 6
...@@ -156,7 +278,8 @@ if __name__ == '__main__': ...@@ -156,7 +278,8 @@ if __name__ == '__main__':
device = 'cpu' device = 'cpu'
# dataset = Reddit('/tmp/Reddit') # dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/PubMed', 'PubMed') dataset = Planetoid('/tmp/Cora', 'Cora')
# dataset = Planetoid('/tmp/PubMed', 'PubMed')
data = dataset[0].to(device) data = dataset[0].to(device)
_bytes = data.edge_index.numel() * 8 _bytes = data.edge_index.numel() * 8
...@@ -169,8 +292,8 @@ if __name__ == '__main__': ...@@ -169,8 +292,8 @@ if __name__ == '__main__':
print(mat1) print(mat1)
mat1 = mat1.t() mat1 = mat1.t()
mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges), mat2 = torch.sparse_coo_tensor(
device=device) data.edge_index, torch.ones(data.num_edges), device=device)
mat2 = mat2.coalesce() mat2 = mat2.coalesce()
mat2 = mat2.t().coalesce() mat2 = mat2.t().coalesce()
...@@ -182,5 +305,12 @@ if __name__ == '__main__': ...@@ -182,5 +305,12 @@ if __name__ == '__main__':
out2 = mat2.to_dense() out2 = mat2.to_dense()
assert torch.allclose(out1, out2) assert torch.allclose(out1, out2)
mat1 = SparseTensor.from_dense(out1) out = 2 + mat1
print(mat1) print(out)
# mat1[1]
# mat1[1, 1]
# mat1[..., -1]
# mat1[:, -1]
# mat1[1:4, 1:4]
# mat1[torch.tensor([0, 1, 2])]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment