Commit f3b7fb50 authored by rusty1s's avatar rusty1s
Browse files

beginning add

parent d30ed1d5
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import add
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_add(dtype, device):
    """Smoke test: adding two SparseTensors with different sparsity patterns.

    NOTE(review): removed a leftover debug ``print()``; the test still only
    checks that ``add`` runs without raising — add result assertions once
    sparse-sparse add is implemented.
    """
    index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
    mat1 = SparseTensor(index)
    index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
    mat2 = SparseTensor(index)
    add(mat1, mat2)
import torch

# NOTE(review): `dtypes` was assigned twice — `[torch.float, torch.double]`
# immediately shadowed by `[torch.float]` (likely a temporary debug
# restriction). The dead first assignment is removed; re-add torch.double
# once double precision is meant to be tested.
dtypes = [torch.float]

devices = [torch.device('cpu')]
if torch.cuda.is_available():
......
import torch
def union(mat1, mat2):
    """Merge the sparsity patterns of two sparse matrices.

    Tags every entry of ``mat1`` with value ``1`` and every entry of
    ``mat2`` with ``offset`` (strictly larger than any ``mat1`` tag), adds
    the two matrices via scipy, and recovers per-entry membership from the
    summed tags (1, ``offset``, or ``1 + offset``).

    Returns:
        ``(rowptr, index, mask1, mask2)`` — the CSR row pointer of the
        union, the ``(2, nnz)`` union index, and boolean masks flagging
        which union entries originate from ``mat1`` / ``mat2``.

    Raises:
        NotImplementedError: for CUDA inputs (no scipy path on GPU).
    """
    offset = mat1.nnz() + 1
    # NOTE(review): tag tensors must live on the same device as the matrix
    # they annotate; the original allocated `value1` on `mat2.device`, which
    # trips the device assertion in `set_value` for mixed-device inputs.
    value1 = torch.ones(mat1.nnz(), dtype=torch.long, device=mat1.device)
    value2 = torch.full((mat2.nnz(), ), offset, dtype=torch.long,
                        device=mat2.device)
    size = max(mat1.size(0), mat2.size(0)), max(mat1.size(1), mat2.size(1))

    if not mat1.is_cuda:
        # Route through scipy CSR, where sparse+sparse addition exists.
        mat1 = mat1.set_value(value1, layout='coo').to_scipy(layout='csr')
        mat1.resize(*size)
        mat2 = mat2.set_value(value2, layout='coo').to_scipy(layout='csr')
        mat2.resize(*size)

        out = mat1 + mat2

        rowptr = torch.from_numpy(out.indptr).to(torch.long)
        out = out.tocoo()
        row = torch.from_numpy(out.row).to(torch.long)
        col = torch.from_numpy(out.col).to(torch.long)
        value = torch.from_numpy(out.data)
    else:
        raise NotImplementedError

    mask1 = value % offset > 0  # tag 1 or 1+offset -> entry came from mat1
    mask2 = value >= offset     # tag offset or 1+offset -> entry came from mat2

    return rowptr, torch.stack([row, col], dim=0), mask1, mask2
def add(src, other):
    """Add ``other`` to the sparse tensor ``src``.

    Supported ``other`` types:
      * ``int`` / ``float`` — added to every stored value (see ``add_nnz``),
      * dense ``(N, 1)`` column or ``(1, M)`` row tensor — broadcast over
        the sparsity pattern,
      * another sparse tensor — pattern union (not implemented yet).

    Raises:
        ValueError: for unsupported types or mismatched dense shapes.
    """
    if isinstance(other, (int, float)):
        return add_nnz(src, other)
    elif torch.is_tensor(other):
        (row, col), value = src.coo()
        if other.size(0) == src.size(0) and other.size(1) == 1:
            # Column vector: gather the per-row summand for every non-zero.
            # NOTE(review): the original used `repeat_interleave(row, 0)`,
            # which expects per-element repeat *counts*, not row indices;
            # indexing with `row` is the correct gather.
            val = other.squeeze(1)[row] + (value if src.has_value() else 1)
        # NOTE(review): this must be `elif` — as a plain `if`, a matching
        # column vector still fell through to the `else` and raised.
        elif other.size(0) == 1 and other.size(1) == src.size(1):
            # Row vector: gather the per-column summand for every non-zero.
            val = other.squeeze(0)[col] + (value if src.has_value() else 1)
        else:
            raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                             f' ...) or (1, {src.size(1)}, ...), but got size '
                             f'{other.size()}.')
        return src.set_value(val, layout='coo')
    elif isinstance(other, src.__class__):
        rowptr, index, src_offset, other_offset = union(src, other)
        raise NotImplementedError
    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')
def add_nnz(src, other, layout=None):
    """Add ``other`` to the stored (non-zero) values of ``src``.

    Entries without an explicit value tensor are treated as implicit ones,
    so the result value becomes ``1 + other`` for them.

    Raises:
        ValueError: if ``other`` is neither a scalar nor a tensor.
    """
    if isinstance(other, (int, float)):
        if src.has_value():
            new_value = src.storage.value + other
        else:
            # Implicit values are 1 -> materialize `1 + other` per entry.
            new_value = torch.full((src.nnz(), ), 1 + other,
                                   device=src.device)
        return src.set_value(new_value)
    if torch.is_tensor(other):
        new_value = src.storage.value + other if src.has_value() else other + 1
        return src.set_value(new_value)
    raise ValueError('Argument `other` needs to be of type `int`, `float` or '
                     '`torch.tensor`.')
......@@ -38,9 +38,17 @@ class SparseStorage(object):
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
]
def __init__(self, index, value=None, sparse_size=None, rowcount=None,
rowptr=None, colcount=None, colptr=None, csr2csc=None,
csc2csr=None, is_sorted=False):
def __init__(self,
index,
value=None,
sparse_size=None,
rowcount=None,
rowptr=None,
colcount=None,
colptr=None,
csr2csc=None,
csc2csr=None,
is_sorted=False):
assert index.dtype == torch.long
assert index.dim() == 2 and index.size(0) == 2
......@@ -130,14 +138,26 @@ class SparseStorage(object):
assert value.size(0) == self._index.size(1)
if value is not None and get_layout(layout) == 'csc':
value = value[self.csc2csr]
return self.apply_value_(lambda x: value)
self._value = value
return self
def set_value(self, value, layout=None):
    """Return a new storage sharing this one's index and layout caches but
    carrying ``value`` as its value tensor.

    If ``layout`` resolves to ``'csc'``, ``value`` is permuted into CSR
    (row-major) order first, which is the order the storage keeps
    internally.

    NOTE(review): removed a merge artifact — a stale
    ``return self.apply_value(lambda x: value)`` preceding the constructor
    call made the new implementation unreachable. Also note the asserts
    dereference ``value`` before the ``value is not None`` check, so a
    ``None`` value would fail the first assert — presumably callers never
    pass ``None`` here; verify against call sites.
    """
    assert value.device == self._index.device
    assert value.size(0) == self._index.size(1)
    if value is not None and get_layout(layout) == 'csc':
        value = value[self.csc2csr]
    return self.__class__(
        self._index,
        value,
        self._sparse_size,
        self._rowcount,
        self._rowptr,
        self._colcount,
        self._colptr,
        self._csr2csc,
        self._csc2csr,
        is_sorted=True,
    )
def sparse_size(self, dim=None):
    """Return the sparse size, or a single dimension of it if ``dim`` is given."""
    if dim is None:
        return self._sparse_size
    return self._sparse_size[dim]
......
......@@ -10,12 +10,13 @@ from torch_sparse.narrow import narrow
from torch_sparse.select import select
from torch_sparse.index_select import index_select, index_select_nnz
from torch_sparse.masked_select import masked_select, masked_select_nnz
from torch_sparse.add import add, add_nnz
class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
    """Wrap COO ``index`` (and optional ``value``) in a SparseStorage backend.

    NOTE(review): collapsed a duplicated ``self.storage = ...`` assignment
    (diff artifact) into a single call.
    """
    self.storage = SparseStorage(
        index, value, sparse_size, is_sorted=is_sorted)
@classmethod
def from_storage(self, storage):
......@@ -36,8 +37,8 @@ class SparseTensor(object):
@classmethod
def from_torch_sparse_coo_tensor(self, mat, is_sorted=False):
    """Build a SparseTensor from a ``torch.sparse_coo_tensor``.

    NOTE(review): collapsed a duplicated ``return`` statement (diff
    artifact) into a single call.
    """
    # Only the first two dims form the sparse size; any trailing dims
    # belong to the value tensor.
    return SparseTensor(
        mat._indices(), mat._values(), mat.size()[:2], is_sorted=is_sorted)
@classmethod
def from_scipy(self, mat):
......@@ -54,8 +55,8 @@ class SparseTensor(object):
value = torch.from_numpy(mat.data)
size = mat.shape
storage = SparseStorage(index, value, size, rowptr=rowptr,
colptr=colptr, is_sorted=True)
storage = SparseStorage(
index, value, size, rowptr=rowptr, colptr=colptr, is_sorted=True)
return SparseTensor.from_storage(storage)
......@@ -192,8 +193,8 @@ class SparseTensor(object):
return self.from_storage(self.storage.apply(lambda x: x.cpu()))
def cuda(self, device=None, non_blocking=False, **kwargs):
    """Return a copy of this tensor with every storage tensor moved to CUDA.

    NOTE(review): collapsed a duplicated ``storage = ...`` assignment
    (diff artifact) into a single call.
    """
    storage = self.storage.apply(
        lambda x: x.cuda(device, non_blocking, **kwargs))
    return self.from_storage(storage)
@property
......@@ -215,8 +216,8 @@ class SparseTensor(object):
if dtype == self.dtype:
return self
storage = self.storage.apply_value(
lambda x: x.type(dtype, non_blocking, **kwargs))
storage = self.storage.apply_value(lambda x: x.type(
dtype, non_blocking, **kwargs))
return self.from_storage(storage)
......@@ -285,9 +286,12 @@ class SparseTensor(object):
def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
    """Convert to a ``torch.sparse_coo_tensor``.

    Missing explicit values are materialized as ones of ``dtype``.

    NOTE(review): the call's argument list contained both the old and the
    reformatted new argument lines (diff artifact); reconstructed the
    single intended call.
    """
    index, value = self.coo()
    return torch.sparse_coo_tensor(
        index,
        value if self.has_value() else torch.ones(
            self.nnz(), dtype=dtype, device=self.device),
        self.size(),
        device=self.device,
        requires_grad=requires_grad)
def to_scipy(self, dtype=None, layout=None):
assert self.dim() == 2
......@@ -388,6 +392,8 @@ SparseTensor.index_select = index_select
# Bind the functional implementations as SparseTensor methods, so callers
# can write `mat.add(other)` etc. while each operation lives in its own
# module.
SparseTensor.index_select_nnz = index_select_nnz
SparseTensor.masked_select = masked_select
SparseTensor.masked_select_nnz = masked_select_nnz
SparseTensor.add = add
SparseTensor.add_nnz = add_nnz
# def remove_diag(self):
# raise NotImplementedError
......@@ -424,40 +430,6 @@ SparseTensor.masked_select_nnz = masked_select_nnz
# raise ValueError('Argument needs to be of type `torch.tensor` or '
# 'type `torch_sparse.SparseTensor`.')
# def add_nnz(self):
# def add(self, other, layout=None):
# if __is_scalar__(other):
# if self.has_value:
# return self.set_value(self._value + other, 'coo')
# else:
# return self.set_value(torch.full((self.nnz(), ), other + 1),
# 'coo')
# elif torch.is_tensor(other):
# if layout is None:
# layout = 'coo'
# warnings.warn('`layout` argument unset, using default layout '
# '"coo". This may lead to unexpected behaviour.')
# assert layout in ['coo', 'csr', 'csc']
# if layout == 'csc':
# other = other[self._arg_csc2csr]
# if self.has_value:
# return self.set_value(self._value + other, 'coo')
# else:
# return self.set_value(other + 1, 'coo')
# elif isinstance(other, self.__class__):
# raise NotImplementedError
# raise ValueError('Argument needs to be of type `int`, `float`, '
# '`torch.tensor` or `torch_sparse.SparseTensor`.')
# def add_(self, other, layout=None):
# if isinstance(other, int) or isinstance(other, float):
# raise NotImplementedError
# elif torch.is_tensor(other):
# raise NotImplementedError
# raise ValueError('Argument needs to be a scalar or of type '
# '`torch.tensor`.')
# def __add__(self, other):
# return self.add(other)
......@@ -498,6 +470,12 @@ if __name__ == '__main__':
perm = torch.arange(data.num_nodes)
perm = torch.randperm(data.num_nodes)
mat1 = SparseTensor(torch.tensor([[0, 1], [0, 1]]))
mat2 = SparseTensor(torch.tensor([[0, 0, 1], [0, 1, 0]]))
add(mat1, mat2)
# print(mat2)
raise NotImplementedError
for _ in range(10):
x = torch.randn(1000, 1000, device=device).sum()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment