Commit bb963e32 authored by rusty1s

narrow implementation

parent 195053ef
torch_sparse/narrow.py

import torch
from torch_sparse.tensor import SparseTensor


def narrow(src, dim, start, length):
    if dim == 0:
        (row, col), value = src.coo()
        rowptr, _, _ = src.csr()

        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        row_start = int(rowptr[0])
        rowptr = rowptr - row_start
        row_length = int(rowptr[-1])

        row = row.narrow(0, row_start, row_length) - start
        col = col.narrow(0, row_start, row_length)
        index = torch.stack([row, col], dim=0)

        if src.has_value():
            value = value.narrow(0, row_start, row_length)

        sparse_size = torch.Size([length, src.sparse_size(1)])

        storage = src._storage.__class__(
            index, value, sparse_size, rowptr=rowptr, is_sorted=True)

    elif dim == 1:
        # This is faster than accessing `csc()` in analogy to the `dim=0`
        # case, since narrowing columns only needs a mask over `col`.
        (row, col), value = src.coo()
        mask = (col >= start) & (col < start + length)
        index = torch.stack([row, col - start], dim=0)[:, mask]

        if src.has_value():
            value = value[mask]

        sparse_size = torch.Size([src.sparse_size(0), length])

        storage = src._storage.__class__(
            index, value, sparse_size, is_sorted=True)

    else:
        # Dense dimensions are narrowed directly on the value tensor.
        storage = src._storage.apply_value(
            lambda x: x.narrow(dim - 1, start, length))

    return src.__class__.from_storage(storage)


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    row = torch.tensor([0, 0, 1, 1], device=device)
    col = torch.tensor([1, 2, 0, 2], device=device)

    sparse_mat = SparseTensor(torch.stack([row, col], dim=0))
    print(sparse_mat)
    print(sparse_mat.to_dense())
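The dim=0 branch narrows rows purely through CSR pointer arithmetic: slicing length + 1 entries out of rowptr yields the non-zero range [row_start, row_start + row_length) covered by the selected rows, so row, col and value each need only a single narrow. A minimal sketch of that arithmetic on raw tensors, outside of SparseTensor (the matrix below is made up for illustration):

import torch

# Made-up 4 x 3 matrix in CSR/COO form.
# Rows: 0 -> {1, 2}, 1 -> {0}, 2 -> {}, 3 -> {2}.
rowptr = torch.tensor([0, 2, 3, 3, 4])
row = torch.tensor([0, 0, 1, 3])
col = torch.tensor([1, 2, 0, 2])

start, length = 1, 2  # keep rows 1 and 2

ptr = rowptr.narrow(0, start, length + 1)  # pointers of the kept rows
nnz_start = int(ptr[0])        # first kept non-zero
ptr = ptr - nnz_start          # re-base the pointers at zero
nnz_length = int(ptr[-1])      # number of kept non-zeros

sub_row = row.narrow(0, nnz_start, nnz_length) - start
sub_col = col.narrow(0, nnz_start, nnz_length)

print(ptr)      # tensor([0, 1, 1]): row 1 has one entry, row 2 none
print(sub_row)  # tensor([0])
print(sub_col)  # tensor([0])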
torch_sparse/storage.py

@@ -166,6 +166,12 @@ class SparseStorage(object):
    def coalesce(self):
        raise NotImplementedError

    def cached_keys(self):
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

    def fill_cache_(self, *args):
        for arg in args or self.cache_keys:
            getattr(self, arg)
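cached_keys() filters the static cache_keys list down to the caches that have actually been materialized; apply_() and apply() below iterate over it so that func only touches existing tensors. A self-contained sketch of the pattern (class and values are made up for illustration):

class Storage:
    # Made-up illustration of the SparseStorage caching pattern.
    cache_keys = ['rowptr', 'colptr']

    def __init__(self):
        self._rowptr = None
        self._colptr = None

    def cached_keys(self):
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]


cache = Storage()
print(cache.cached_keys())  # []
cache._rowptr = 'materialized'
print(cache.cached_keys())  # ['rowptr']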
@@ -206,8 +212,8 @@ class SparseStorage(object):
    def apply_(self, func):
        self._index = func(self._index)
        self._value = optional(func, self._value)
        for key in self.cached_keys():
            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
        return self

    def apply(self, func):
@@ -226,10 +232,7 @@ class SparseStorage(object):
        data = [func(self.index)]
        if self.has_value():
            data += [func(self.value)]
        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
        return data
torch_sparse/tensor.py

@@ -4,7 +4,9 @@ import torch
import scipy.sparse

from torch_sparse.storage import SparseStorage
from torch_sparse.transpose import t
from torch_sparse.narrow import narrow


class SparseTensor(object):
@@ -77,8 +79,10 @@ class SparseTensor(object):
        return self._storage.is_coalesced()

    def coalesce(self):
        return self.__class__.from_storage(self._storage.coalesce())

    def cached_keys(self):
        return self._storage.cached_keys()

    def fill_cache_(self, *args):
        self._storage.fill_cache_(*args)
@@ -139,7 +143,6 @@ class SparseTensor(object):
    def detach(self):
        storage = self._storage.apply(lambda x: x.detach())
        return self.__class__.from_storage(storage)

    def pin_memory(self):
@@ -265,27 +268,29 @@ class SparseTensor(object):
            requires_grad=requires_grad)

    def to_scipy(self, dtype=None, layout='coo'):
        assert self.dim() == 2
        assert layout in self._storage.layouts

        if not self.has_value():
            ones = torch.ones(self.nnz(), dtype=dtype).numpy()

        if layout == 'coo':
            (row, col), value = self.coo()
            row = row.detach().cpu().numpy()
            col = col.detach().cpu().numpy()
            value = value.detach().cpu().numpy() if self.has_value() else ones
            return scipy.sparse.coo_matrix((value, (row, col)), self.size())
        elif layout == 'csr':
            rowptr, col, value = self.csr()
            rowptr = rowptr.detach().cpu().numpy()
            col = col.detach().cpu().numpy()
            value = value.detach().cpu().numpy() if self.has_value() else ones
            return scipy.sparse.csr_matrix((value, col, rowptr), self.size())
        elif layout == 'csc':
            colptr, row, value = self.csc()
            colptr = colptr.detach().cpu().numpy()
            row = row.detach().cpu().numpy()
            value = value.detach().cpu().numpy() if self.has_value() else ones
            return scipy.sparse.csc_matrix((value, row, colptr), self.size())
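Calling detach().cpu().numpy() per tensor means only the arrays the requested layout actually needs are copied off the device. The scipy constructors take (data, (row, col)) for COO and (data, indices, indptr) for CSR/CSC; a self-contained round-trip sketch on made-up tensors:

import torch
import scipy.sparse

# Made-up 2 x 3 sparse matrix in COO form.
row = torch.tensor([0, 0, 1])
col = torch.tensor([1, 2, 0])
value = torch.randn(3)

coo = scipy.sparse.coo_matrix(
    (value.numpy(), (row.numpy(), col.numpy())), shape=(2, 3))
csr = coo.tocsr()
print(coo.todense().shape)  # (2, 3)

# Round-trip back to torch tensors.
back = csr.tocoo()
print(torch.equal(torch.from_numpy(back.row.astype('int64')), row))  # True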
# String Representation ###################################################
@@ -312,6 +317,7 @@ class SparseTensor(object):
# Bindings ####################################################################
SparseTensor.t = t
SparseTensor.narrow = narrow
# def set_diag(self, value):
# raise NotImplementedError
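The bindings attach the standalone functions as methods, which is why mat1.t() and mat1.narrow(...) in the test block below dispatch with the tensor as the first argument. A tiny sketch of the mechanism (names are made up):

class Mat:
    pass

def scale(self, x):
    return 2 * x

Mat.scale = scale       # a plain function assigned to a class becomes a method
print(Mat().scale(21))  # 42 -- `self` is bound automatically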
@@ -434,33 +440,48 @@ if __name__ == '__main__':
    dataset = Planetoid('/tmp/Cora', 'Cora')
    data = dataset[0].to(device)

    value = torch.randn((data.num_edges, ), device=device)
    mat1 = SparseTensor(data.edge_index, value)
    # print(mat1)
    # # print(mat1.to_dense().size())
    # print(mat1.to_torch_sparse_coo_tensor().to_dense().size())
    # print(mat1.to_scipy(layout='coo').todense().shape)
    # print(mat1.to_scipy(layout='csr').todense().shape)
    # print(mat1.to_scipy(layout='csc').todense().shape)
    # print(mat1.is_quadratic())
    # print(mat1.is_symmetric())
    # print(mat1.cached_keys())
    # mat1 = mat1.t()
    # print(mat1.cached_keys())
    # mat1 = mat1.t()
    # print(mat1.cached_keys())

    # print('-------- NARROW ----------')
    t = time.perf_counter()
    for _ in range(100):
        out = mat1.narrow(dim=0, start=10, length=10)
        # out._storage.colptr
    print(time.perf_counter() - t)
    print(out)
    print(out.cached_keys())

    t = time.perf_counter()
    for _ in range(100):
        out = mat1.narrow(dim=1, start=10, length=2000)
        # out._storage.colptr
    print(time.perf_counter() - t)
    print(out)
    print(out.cached_keys())
    # mat1 = mat1.narrow(0, start=10, length=10)
    # mat1._storage._value = torch.randn(mat1.nnz(), 20)
    # print(mat1.coo()[1].size())
    # mat1 = mat1.narrow(2, start=10, length=10)
    # print(mat1.coo()[1].size())

    # mat1 = mat1.t()
    # mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
torch_sparse/transpose.py

@@ -29,14 +29,17 @@ def transpose(index, value, m, n, coalesced=True):

def t(mat):
    (row, col), value = mat.coo()
    csr_to_csc = mat._storage.csr_to_csc

    storage = mat._storage.__class__(
        index=torch.stack([col, row], dim=0)[:, csr_to_csc],
        value=value[csr_to_csc] if mat.has_value() else None,
        sparse_size=mat.sparse_size()[::-1],
        rowptr=mat._storage._colptr,
        colptr=mat._storage._rowptr,
        csr_to_csc=mat._storage._csc_to_csr,
        csc_to_csr=csr_to_csc,
        is_sorted=True)

    return mat.__class__.from_storage(storage)
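t() avoids any re-sorting by reusing the cached csr_to_csc permutation: ordering the non-zeros by column is exactly the row-major order of the transpose, so the cached rowptr and colptr simply swap roles. The permutation itself amounts to a stable argsort over col; a minimal sketch on made-up tensors:

import torch

# Made-up 2 x 3 matrix in row-major COO form.
row = torch.tensor([0, 0, 1, 1])
col = torch.tensor([1, 2, 0, 2])
value = torch.tensor([10., 20., 30., 40.])

# Stable sort by column == row-major order of the transpose.
csr_to_csc = torch.argsort(col, stable=True)

t_row = col[csr_to_csc]   # rows of the transpose
t_col = row[csr_to_csc]   # cols of the transpose
t_value = value[csr_to_csc]

print(t_row)    # tensor([0, 1, 2, 2])
print(t_col)    # tensor([1, 0, 0, 1])
print(t_value)  # tensor([30., 10., 20., 40.])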