Commit fcf88282 authored by rusty1s

tensor fixes

parent 918b1163
@@ -37,6 +37,7 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
// TODO: Make more efficient.
if (thread_idx < numel) {
int64_t idx = ptr_data[thread_idx], next_idx = ptr_data[thread_idx + 1];
for (int64_t i = idx; i < next_idx; i++) {
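For orientation, `ptr2ind` expands a CSR row pointer of length M + 1 into per-entry row indices of length nnz, which is exactly what the kernel's inner loop writes out per thread. A minimal CPU sketch of the same computation (the helper name `ptr2ind_reference` is ours, not part of the library):

import torch

def ptr2ind_reference(rowptr):
    # Row i is repeated rowptr[i + 1] - rowptr[i] times.
    counts = rowptr[1:] - rowptr[:-1]
    return torch.arange(rowptr.numel() - 1,
                        device=rowptr.device).repeat_interleave(counts)

rowptr = torch.tensor([0, 2, 2, 5])  # 3 rows, 5 non-zero entries
ptr2ind_reference(rowptr)            # tensor([0, 0, 2, 2, 2])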
@@ -99,7 +99,6 @@ class SparseStorage(object):
if value is not None:
assert value.device == col.device
assert value.size(0) == col.size(0)
value = value.contiguous()
if rowcount is not None:
assert rowcount.dtype == torch.long
@@ -160,7 +159,7 @@ class SparseStorage(object):
def row(self):
if self._row is None:
func = convert_cuda if self.rowptr.is_cuda else convert_cpu
self._row = func.ptr2ind(self.rowptr, self.nnz())
self._row = func.ptr2ind(self.rowptr, self.col.numel())
return self._row
def has_rowptr(self):
@@ -184,9 +183,9 @@ class SparseStorage(object):
def value(self):
return self._value
def set_value_(self, value, dtype=None, layout=None):
def set_value_(self, value, layout=None, dtype=None):
if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), value, dtype=dtype,
value = torch.full((self.col.numel(), ), value, dtype=dtype,
device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
@@ -200,9 +199,9 @@ class SparseStorage(object):
self._value = value
return self
def set_value(self, value, dtype=None, layout=None):
def set_value(self, value, layout=None, dtype=None):
if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), value, dtype=dtype,
value = torch.full((self.col.numel(), ), value, dtype=dtype,
device=self.col.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
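Both setters now take `layout` before `dtype`, and a plain int or float is expanded into one value per non-zero entry. A hypothetical usage sketch (assuming the package's usual top-level `SparseTensor` export):

import torch
from torch_sparse import SparseTensor

adj = SparseTensor.eye(3)
# `layout` precedes `dtype` now; the scalar 2 becomes
# torch.full((adj.nnz(), ), 2, dtype=torch.float, device=...).
adj = adj.set_value(2, layout=None, dtype=torch.float)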
@@ -224,7 +223,7 @@ class SparseStorage(object):
return self._sparse_size
def sparse_resize(self, *sizes):
old_sparse_size, nnz = self.sparse_size, self.nnz()
old_sparse_size, nnz = self.sparse_size, self.col.numel()
diff_0 = sizes[0] - old_sparse_size[0]
rowcount, rowptr = self._rowcount, self._rowptr
@@ -258,9 +257,6 @@ class SparseStorage(object):
colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def nnz(self):
return self.col.numel()
def has_rowcount(self):
return self._rowcount is not None
@@ -13,11 +13,14 @@ from torch_sparse.masked_select import masked_select, masked_select_nnz
import torch_sparse.reduce
from torch_sparse.diag import remove_diag
from torch_sparse.matmul import matmul
from torch_sparse.add import add, add_, add_nnz, add_nnz_
class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
self.storage = SparseStorage(index, value, sparse_size,
def __init__(self, row=None, rowptr=None, col=None, value=None,
sparse_size=None, is_sorted=False):
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
value=value, sparse_size=sparse_size,
is_sorted=is_sorted)
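The constructor no longer takes a stacked [2, nnz] `index` tensor; each layout piece is passed by keyword instead. A sketch (assuming the top-level `SparseTensor` export):

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([1, 2, 0, 2])
adj = SparseTensor(row=row, col=col, value=torch.ones(4),
                   sparse_size=torch.Size([3, 3]), is_sorted=True)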
@classmethod
@@ -33,14 +36,15 @@ class SparseTensor(object):
else:
index = mat.nonzero()
index = index.t().contiguous()
value = mat[index[0], index[1]]
return SparseTensor(index, value, mat.size()[:2], is_sorted=True)
row, col = index.t().contiguous()
return SparseTensor(row=row, col=col, value=mat[row, col],
sparse_size=mat.size()[:2], is_sorted=True)
@classmethod
def from_torch_sparse_coo_tensor(self, mat, is_sorted=False):
return SparseTensor(mat._indices(), mat._values(),
mat.size()[:2], is_sorted=is_sorted)
row, col = mat._indices()
return SparseTensor(row=row, col=col, value=mat._values(),
sparse_size=mat.size()[:2], is_sorted=is_sorted)
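`_indices()` yields a [2, nnz] tensor, so tuple-unpacking it gives `row` and `col` directly. A round-trip sketch:

import torch
from torch_sparse import SparseTensor

index = torch.tensor([[0, 1], [1, 0]])
coo = torch.sparse_coo_tensor(index, torch.tensor([1., 2.]), (2, 2))
# coalesce() sorts the indices by row, so is_sorted=True is safe here.
adj = SparseTensor.from_torch_sparse_coo_tensor(coo.coalesce(), is_sorted=True)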
@classmethod
def from_scipy(self, mat):
@@ -48,60 +52,52 @@ class SparseTensor(object):
if isinstance(mat, scipy.sparse.csc_matrix):
colptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocsr()
mat = mat.tocsr() # Pre-sort.
rowptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocoo()
row = torch.from_numpy(mat.row).to(torch.long)
col = torch.from_numpy(mat.col).to(torch.long)
index = torch.stack([row, col], dim=0)
value = torch.from_numpy(mat.data)
size = mat.shape
sparse_size = mat.shape[:2]
storage = SparseStorage(index, value, size, rowptr=rowptr,
colptr=colptr, is_sorted=True)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_size=sparse_size, colptr=colptr,
is_sorted=True)
return SparseTensor.from_storage(storage)
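For CSC input the column pointer is captured before `tocsr()` re-sorts the entries, so both pointer caches come for free. A usage sketch:

import scipy.sparse
from torch_sparse import SparseTensor

mat = scipy.sparse.random(4, 4, density=0.25, format='csc')
adj = SparseTensor.from_scipy(mat)  # colptr cached, tocsr() pre-sorts rows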
@classmethod
def eye(self, M, N=None, device=None, dtype=None, no_value=False,
def eye(self, M, N=None, device=None, dtype=None, has_value=True,
fill_cache=False):
N = M if N is None else N
index = torch.empty((2, min(M, N)), dtype=torch.long, device=device)
torch.arange(index.size(1), out=index[0])
torch.arange(index.size(1), out=index[1])
row = torch.arange(min(M, N), device=device)
rowptr = torch.arange(M + 1, device=device)
if M > N:
rowptr[row.size(0) + 1:] = row.size(0)
col = row
value = None
if not no_value:
value = torch.ones(index.size(1), dtype=dtype, device=device)
if has_value:
value = torch.ones(row.size(0), dtype=dtype, device=device)
rowcount = rowptr = colcount = colptr = csr2csc = csc2csr = None
rowcount = colptr = colcount = csr2csc = csc2csr = None
if fill_cache:
rowcount = index.new_ones(M)
rowptr = torch.arange(M + 1, device=device)
rowcount = row.new_ones(M)
if M > N:
rowcount[index.size(1):] = 0
rowptr[index.size(1) + 1:] = index.size(1)
colcount = index.new_ones(N)
rowcount[row.size(0):] = 0
colptr = torch.arange(N + 1, device=device)
colcount = col.new_ones(N)
if N > M:
colcount[index.size(1):] = 0
colptr[index.size(1) + 1:] = index.size(1)
csr2csc = torch.arange(index.size(1), device=device)
csc2csr = torch.arange(index.size(1), device=device)
storage = SparseStorage(
index,
value,
torch.Size([M, N]),
rowcount=rowcount,
rowptr=rowptr,
colcount=colcount,
colptr=colptr,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
)
colptr[col.size(0) + 1:] = col.size(0)
colcount[col.size(0):] = 0
csr2csc = csc2csr = row
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_size=torch.Size([M, N]),
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return SparseTensor.from_storage(storage)
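`eye` now builds `row`, `rowptr` and `col` directly instead of a stacked index, and the inverted `no_value` flag became `has_value`. For tall matrices the trailing `rowptr` entries are clamped to nnz, e.g.:

from torch_sparse import SparseTensor

adj = SparseTensor.eye(4, 2, has_value=False, fill_cache=True)
# rowptr == [0, 1, 2, 2, 2]: rows 2 and 3 are empty, so their pointer
# entries stay at nnz == min(M, N) == 2.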
def __copy__(self):
@@ -118,7 +114,7 @@ class SparseTensor(object):
# Formats #################################################################
def coo(self):
return self.storage.index, self.storage.value
return self.storage.row, self.storage.col, self.storage.value
def csr(self):
return self.storage.rowptr, self.storage.col, self.storage.value
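`coo()` now returns three tensors instead of an `(index, value)` pair, matching the shape of `csr()` and `csc()`:

from torch_sparse import SparseTensor

adj = SparseTensor.eye(3)
row, col, value = adj.coo()      # previously: index, value = adj.coo()
rowptr, col, value = adj.csr()
colptr, row, value = adj.csc()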
@@ -133,15 +129,16 @@ class SparseTensor(object):
def has_value(self):
return self.storage.has_value()
def set_value_(self, value, layout=None):
self.storage.set_value_(value, layout)
def set_value_(self, value, layout=None, dtype=None):
self.storage.set_value_(value, layout, dtype)
return self
def set_value(self, value, layout=None):
return self.from_storage(self.storage.set_value(value, layout))
def set_value(self, value, layout=None, dtype=None):
return self.from_storage(self.storage.set_value(value, layout, dtype))
def sparse_size(self, dim=None):
return self.storage.sparse_size(dim)
sparse_size = self.storage.sparse_size
return sparse_size if dim is None else sparse_size[dim]
def sparse_resize(self, *sizes):
return self.from_storage(self.storage.sparse_resize(*sizes))
@@ -165,20 +162,20 @@ class SparseTensor(object):
# Utility functions #######################################################
def dim(self):
return len(self.size())
def size(self, dim=None):
size = self.sparse_size()
size += self.storage.value.size()[1:] if self.has_value() else ()
return size if dim is None else size[dim]
def dim(self):
return len(self.size())
@property
def shape(self):
return self.size()
def nnz(self):
return self.storage.nnz()
return self.storage.col.numel()
def density(self):
return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))
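`nnz()` is now read straight off `col`; `density` relates it to the sparse shape:

from torch_sparse import SparseTensor

adj = SparseTensor.eye(4)
adj.nnz()      # 4, i.e. adj.storage.col.numel()
adj.density()  # 4 / (4 * 4) = 0.25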
@@ -202,11 +199,16 @@ class SparseTensor(object):
if not self.is_quadratic:
return False
rowptr, col, val1 = self.csr()
colptr, row, val2 = self.csc()
index_sym = (rowptr == colptr).all() and (col == row).all()
value_sym = (val1 == val2).all().item() if self.has_value() else True
return index_sym.item() and value_sym
rowptr, col, value1 = self.csr()
colptr, row, value2 = self.csc()
if (rowptr != colptr).any() or (col != row).any():
return False
if not self.has_value():
return True
return (value1 == value2).all().item()
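The rewrite short-circuits on the index comparison before touching the values. A quick check on a symmetric pattern:

import torch
from torch_sparse import SparseTensor

row, col = torch.tensor([0, 1]), torch.tensor([1, 0])
adj = SparseTensor(row=row, col=col, value=torch.ones(2),
                   sparse_size=torch.Size([2, 2]), is_sorted=True)
adj.is_symmetric()  # True: CSR and CSC layouts agree entry by entry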
def detach_(self):
self.storage.apply_(lambda x: x.detach_())
@@ -219,9 +221,13 @@ class SparseTensor(object):
def requires_grad(self):
return self.storage.value.requires_grad if self.has_value() else False
def requires_grad_(self, requires_grad=True):
def requires_grad_(self, requires_grad=True, dtype=None):
if requires_grad and not self.has_value():
self.storage.set_value_(1, dtype=dtype)
if self.has_value():
self.storage.value.requires_grad_(requires_grad)
return self
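A value-less tensor has nothing to track gradients on, so `requires_grad_` now materializes a value of ones first, hence the new `dtype` argument:

import torch
from torch_sparse import SparseTensor

adj = SparseTensor.eye(3, has_value=False)
adj = adj.requires_grad_(dtype=torch.float)  # fills value with ones first
adj.storage.value.requires_grad              # True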
def pin_memory(self):
@@ -239,7 +245,7 @@ class SparseTensor(object):
@property
def device(self):
return self.storage.index.device
return self.storage.col.device
def cpu(self):
return self.from_storage(self.storage.apply(lambda x: x.cpu()))
@@ -251,7 +257,7 @@ class SparseTensor(object):
@property
def is_cuda(self):
return self.storage.index.is_cuda
return self.storage.col.is_cuda
@property
def dtype(self):
@@ -296,7 +302,7 @@ class SparseTensor(object):
if len(args) > 0 or len(kwargs) > 0:
storage = storage.apply_value(lambda x: x.type(*args, **kwargs))
if storage == self.storage: # Nothing changed...
if storage == self.storage: # Nothing has been changed...
return self
else:
return self.from_storage(storage)
@@ -335,19 +341,21 @@ class SparseTensor(object):
def to_dense(self, dtype=None):
dtype = dtype or self.dtype
(row, col), value = self.coo()
row, col, value = self.coo()
mat = torch.zeros(self.size(), dtype=dtype, device=self.device)
mat[row, col] = value if self.has_value() else 1
return mat
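Value-less entries materialize as ones in the dense output:

import torch
from torch_sparse import SparseTensor

adj = SparseTensor.eye(2, has_value=False)
adj.to_dense(dtype=torch.float)
# tensor([[1., 0.],
#         [0., 1.]])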
def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
index, value = self.coo()
return torch.sparse_coo_tensor(
index, value if self.has_value() else torch.ones(
self.nnz(), dtype=dtype, device=self.device), self.size(),
device=self.device, requires_grad=requires_grad)
row, col, value = self.coo()
index = torch.stack([row, col], dim=0)
if value is None:
value = torch.ones(self.nnz(), dtype=dtype, device=self.device)
return torch.sparse_coo_tensor(index, value, self.size(),
device=self.device,
requires_grad=requires_grad)
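The export path restacks `row` and `col` into a [2, nnz] index and falls back to ones when no value is stored:

import torch
from torch_sparse import SparseTensor

adj = SparseTensor.eye(3, has_value=False)
coo = adj.to_torch_sparse_coo_tensor(dtype=torch.float)
coo.to_dense()  # 3 x 3 identity: ones were substituted for the value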
def to_scipy(self, dtype=None, layout=None):
def to_scipy(self, dtype=None, layout="csr"):
assert self.dim() == 2
layout = get_layout(layout)
......@@ -355,7 +363,7 @@ class SparseTensor(object):
ones = torch.ones(self.nnz(), dtype=dtype).numpy()
if layout == 'coo':
(row, col), value = self.coo()
row, col, value = self.coo()
row = row.detach().cpu().numpy()
col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
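`to_scipy` now defaults to CSR instead of requiring an explicit layout; the COO branch detaches and copies each component to NumPy:

from torch_sparse import SparseTensor

adj = SparseTensor.eye(3)
sp_csr = adj.to_scipy()              # layout defaults to 'csr' now
sp_coo = adj.to_scipy(layout='coo')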
@@ -379,7 +387,7 @@ class SparseTensor(object):
index = list(index) if isinstance(index, tuple) else [index]
# More than one `Ellipsis` is not allowed...
if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
raise SyntaxError()
raise SyntaxError
dim = 0
out = self
@@ -416,6 +424,15 @@ class SparseTensor(object):
return out
def __add__(self, other):
return self.add(other)
def __radd__(self, other):
return self.add(other)
def __iadd__(self, other):
return self.add_(other)
def __matmul__(a, b):
return matmul(a, b, reduce='sum')
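With the new dunder methods, Python operators dispatch to the routines bound at the bottom of the file. A sketch (sparse-dense matmul is the library's core operation):

import torch
from torch_sparse import SparseTensor

adj = SparseTensor.eye(3)
x = torch.randn(3, 4)
y = adj @ x  # __matmul__ -> matmul(adj, x, reduce='sum')
# adj + other and adj += other now route to adj.add(other) / adj.add_(other).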
@@ -423,8 +440,10 @@ class SparseTensor(object):
def __repr__(self):
i = ' ' * 6
index, value = self.coo()
infos = [f'index={indent(index.__repr__(), i)[len(i):]}']
row, col, value = self.coo()
infos = []
infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
if self.has_value():
infos += [f'value={indent(value.__repr__(), i)[len(i):]}']
@@ -456,8 +475,10 @@ SparseTensor.min = torch_sparse.reduce.min
SparseTensor.max = torch_sparse.reduce.max
SparseTensor.remove_diag = remove_diag
SparseTensor.matmul = matmul
# SparseTensor.add = add
# SparseTensor.add_nnz = add_nnz
SparseTensor.add = add
SparseTensor.add_ = add_
SparseTensor.add_nnz = add_nnz
SparseTensor.add_nnz_ = add_nnz_
# def __add__(self, other):
# return self.add(other)