"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c50d997591d14dfa2030b015d2a5934add658b1d"
Commit bb963e32 authored by rusty1s

narrow implementation

parent 195053ef
torch_sparse/narrow.py
 import torch
+from torch_sparse.tensor import SparseTensor


 def narrow(src, dim, start, length):
     if dim == 0:
-        col, rowptr, value = src.csr()
-        rowptr = rowptr.narrow(0, start=start, length=length)
-        row_start, row_end = rowptr[0]
-        row_length = rowptr[-1] - row_start
+        (row, col), value = src.coo()
+        rowptr, _, _ = src.csr()
+
+        rowptr = rowptr.narrow(0, start=start, length=length + 1)
+        row_start = rowptr[0]
+        rowptr = rowptr - row_start
+        row_length = rowptr[-1]
+
+        row = row.narrow(0, row_start, row_length) - start
         col = col.narrow(0, row_start, row_length)
-        row = self._row.narrow(0, row_start, row_length)
+        index = torch.stack([row, col], dim=0)
+
+        if src.has_value():
+            value = value.narrow(0, row_start, row_length)
+
+        sparse_size = torch.Size([length, src.sparse_size(1)])
+
+        storage = src._storage.__class__(
+            index, value, sparse_size, rowptr, is_sorted=True)
+
+    elif dim == 1:
+        # This is faster than accessing `csc()` in analogy to the `dim=0` case.
+        (row, col), value = src.coo()
+        mask = (col >= start) & (col < start + length)
+
+        index = torch.stack([row, col - start], dim=0)[:, mask]
+        if src.has_value():
+            value = value[mask]
+
+        sparse_size = torch.Size([src.sparse_size(0), length])
+
+        storage = src._storage.__class__(
+            index, value, sparse_size, is_sorted=True)
-    elif dim == 0:
     else:
-        pass
+        storage = src._storage.apply_value(lambda x: x.narrow(
+            dim - 1, start, length))
+
+    return src.__class__.from_storage(storage)
+
+
+if __name__ == '__main__':
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    row = torch.tensor([0, 0, 1, 1], device=device)
+    col = torch.tensor([1, 2, 0, 2], device=device)
+
+    sparse_mat = SparseTensor(torch.stack([row, col], dim=0))
+    print(sparse_mat)
+    print(sparse_mat.to_dense())
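The `dim=0` branch above never builds a mask: in CSR order the nonzeros of consecutive rows are contiguous, so slicing `length + 1` entries of `rowptr` is enough to locate them. A minimal standalone sketch of that arithmetic on plain torch tensors (the matrix data here is illustrative, not from the commit):

import torch

# A 4x3 matrix in CSR/COO form with 6 nonzeros:
rowptr = torch.tensor([0, 2, 4, 5, 6])
row = torch.tensor([0, 0, 1, 1, 2, 3])
col = torch.tensor([1, 2, 0, 2, 1, 0])

start, length = 1, 2  # keep rows 1 and 2

# `length + 1` rowptr entries describe `length` rows:
out_rowptr = rowptr.narrow(0, start, length + 1)  # tensor([2, 4, 5])
row_start = int(out_rowptr[0])                    # kept nonzeros begin at offset 2
out_rowptr = out_rowptr - row_start               # tensor([0, 2, 3])
row_length = int(out_rowptr[-1])                  # 3 nonzeros survive

out_row = row.narrow(0, row_start, row_length) - start  # tensor([0, 0, 1])
out_col = col.narrow(0, row_start, row_length)          # tensor([0, 2, 1])

Narrowing along `dim=1` has no such contiguity to exploit, which is why that branch falls back to a boolean mask over `col`.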
torch_sparse/storage.py
@@ -166,6 +166,12 @@ class SparseStorage(object):
     def coalesce(self):
         raise NotImplementedError

+    def cached_keys(self):
+        return [
+            key for key in self.cache_keys
+            if getattr(self, f'_{key}', None) is not None
+        ]
+
     def fill_cache_(self, *args):
         for arg in args or self.cache_keys:
             getattr(self, arg)
@@ -206,8 +212,8 @@ class SparseStorage(object):
     def apply_(self, func):
         self._index = func(self._index)
         self._value = optional(func, self._value)
-        for key in self.cache_keys:
-            setattr(self, f'_{key}', optional(func, getattr(self, f'_{key}')))
+        for key in self.cached_keys():
+            setattr(self, f'_{key}', func(getattr(self, f'_{key}')))
         return self

     def apply(self, func):
@@ -226,10 +232,7 @@ class SparseStorage(object):
         data = [func(self.index)]
         if self.has_value():
             data += [func(self.value)]
-        data += [
-            func(getattr(self, f'_{key}')) for key in self.cache_keys
-            if getattr(self, f'_{key}')
-        ]
+        data += [func(getattr(self, f'_{key}')) for key in self.cached_keys()]
         return data
...
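`cached_keys()` reports only the caches that have actually been materialized, which is what lets `apply_` run `func` over exactly the tensors that exist. A toy sketch of that contract (a stand-in class, not the real `SparseStorage`):

import torch

class ToyStorage(object):
    cache_keys = ['rowptr', 'colptr']

    def __init__(self):
        self._rowptr = torch.tensor([0, 2, 3])  # materialized cache
        self._colptr = None                     # never computed

    def cached_keys(self):
        return [
            key for key in self.cache_keys
            if getattr(self, f'_{key}', None) is not None
        ]

print(ToyStorage().cached_keys())  # ['rowptr'] -- colptr is skipped

The `is not None` test also fixes a latent bug in the old comprehension: truth-testing a cached tensor with more than one element raises "Boolean value of Tensor with more than one element is ambiguous".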
torch_sparse/tensor.py
@@ -4,7 +4,9 @@
 import torch
 import scipy.sparse

 from torch_sparse.storage import SparseStorage
 from torch_sparse.transpose import t
+from torch_sparse.narrow import narrow


 class SparseTensor(object):
@@ -77,8 +79,10 @@ class SparseTensor(object):
         return self._storage.is_coalesced()

     def coalesce(self):
-        storage = self._storage.coalesce()
-        return self.__class__.from_storage(storage)
+        return self.__class__.from_storage(self._storage.coalesce())
+
+    def cached_keys(self):
+        return self._storage.cached_keys()

     def fill_cache_(self, *args):
         self._storage.fill_cache_(*args)
@@ -139,7 +143,6 @@ class SparseTensor(object):
     def detach(self):
         storage = self._storage.apply(lambda x: x.detach())
-        print("AWDAwd")
         return self.__class__.from_storage(storage)

     def pin_memory(self):
@@ -265,27 +268,29 @@ class SparseTensor(object):
                               requires_grad=requires_grad)

     def to_scipy(self, dtype=None, layout='coo'):
+        assert self.dim() == 2
         assert layout in self._storage.layouts
-        self = self.detach().cpu()
-
-        if self.has_value():
-            value = self._storage.value.numpy()
-            assert value.ndim == 1
-        else:
-            value = torch.ones(self.nnz(), dtype=dtype).numpy()
+        if not self.has_value():
+            ones = torch.ones(self.nnz(), dtype=dtype).numpy()

         if layout == 'coo':
-            (row, col), _ = self.coo()
-            row, col = row.numpy(), col.numpy()
+            (row, col), value = self.coo()
+            row = row.detach().cpu().numpy()
+            col = col.detach().cpu().numpy()
+            value = value.detach().cpu().numpy() if self.has_value() else ones
             return scipy.sparse.coo_matrix((value, (row, col)), self.size())
         elif layout == 'csr':
-            rowptr, col, _ = self.csr()
-            rowptr, col = rowptr.numpy(), col.numpy()
+            rowptr, col, value = self.csr()
+            rowptr = rowptr.detach().cpu().numpy()
+            col = col.detach().cpu().numpy()
+            value = value.detach().cpu().numpy() if self.has_value() else ones
             return scipy.sparse.csr_matrix((value, col, rowptr), self.size())
         elif layout == 'csc':
-            colptr, row, _ = self.csc()
-            colptr, row = colptr.numpy(), row.numpy()
+            colptr, row, value = self.csc()
+            colptr = colptr.detach().cpu().numpy()
+            row = row.detach().cpu().numpy()
+            value = value.detach().cpu().numpy() if self.has_value() else ones
             return scipy.sparse.csc_matrix((value, row, colptr), self.size())
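A short usage sketch for the rewritten `to_scipy`; the constructor call mirrors the one in the test block below (`SparseTensor(index, value)`), and the data is made up:

import torch
from torch_sparse.tensor import SparseTensor

index = torch.tensor([[0, 0, 1],   # row indices of a 2x3 matrix
                      [1, 2, 0]])  # col indices
value = torch.tensor([1., 2., 3.])

mat = SparseTensor(index, value)
csr = mat.to_scipy(layout='csr')   # scipy.sparse.csr_matrix
print(csr.todense())

Since each tensor is now moved with `.detach().cpu()` individually, a CUDA-resident matrix converts without the old `self = self.detach().cpu()` rebinding trick.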
     # String Representation ###################################################
@@ -312,6 +317,7 @@ class SparseTensor(object):
 # Bindings ####################################################################

 SparseTensor.t = t
+SparseTensor.narrow = narrow

 # def set_diag(self, value):
 #     raise NotImplementedError
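These bindings use the fact that assigning a plain function to a class attribute makes it an instance method, so `mat.narrow(...)` passes `mat` as the `src` parameter of `narrow`. A toy illustration of the pattern (stand-in class, not the real `SparseTensor`):

class Mat(object):
    def __init__(self, data):
        self.data = data

def scale(src, factor):
    return Mat([factor * x for x in src.data])

Mat.scale = scale                 # same pattern as SparseTensor.narrow = narrow
print(Mat([1, 2]).scale(3).data)  # [3, 6]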
@@ -434,33 +440,48 @@ if __name__ == '__main__':
     dataset = Planetoid('/tmp/Cora', 'Cora')
     data = dataset[0].to(device)

-    value = torch.ones((data.num_edges, ), device=device)
+    value = torch.randn((data.num_edges, ), device=device)
     mat1 = SparseTensor(data.edge_index, value)
-    print(mat1)
-    print(id(mat1))
-    mat1 = mat1.long()
-    print(id(mat1))
-    mat1 = mat1.long()
-    print(id(mat1))
-    mat1 = mat1.to(torch.bool)
-    print(mat1)
-    print(mat1.is_pinned())
-    print(mat1.to_dense().size())
-    mat2 = mat1.to_torch_sparse_coo_tensor()
-    print(mat2)
-    print(mat1.to_scipy(layout='coo').todense().shape)
-    print(mat1.to_scipy(layout='csr').todense().shape)
-    print(mat1.to_scipy(layout='csc').todense().shape)
-    print(mat1.is_quadratic())
-    print(mat1.is_symmetric())
-    mat1 = mat1.t()
-    print(mat1)
+    # print(mat1)
+
+    # print(mat1.to_dense().size())
+    # print(mat1.to_torch_sparse_coo_tensor().to_dense().size())
+    # print(mat1.to_scipy(layout='coo').todense().shape)
+    # print(mat1.to_scipy(layout='csr').todense().shape)
+    # print(mat1.to_scipy(layout='csc').todense().shape)
+
+    # print(mat1.is_quadratic())
+    # print(mat1.is_symmetric())
+
+    # print(mat1.cached_keys())
+    # mat1 = mat1.t()
+    # print(mat1.cached_keys())
+    # mat1 = mat1.t()
+    # print(mat1.cached_keys())
+
+    # print('-------- NARROW ----------')
+
+    t = time.perf_counter()
+    for _ in range(100):
+        out = mat1.narrow(dim=0, start=10, length=10)
+        # out._storage.colptr
+    print(time.perf_counter() - t)
+    print(out)
+    print(out.cached_keys())
+
+    t = time.perf_counter()
+    for _ in range(100):
+        out = mat1.narrow(dim=1, start=10, length=2000)
+        # out._storage.colptr
+    print(time.perf_counter() - t)
+    print(out)
+    print(out.cached_keys())
+
+    # mat1 = mat1.narrow(0, start=10, length=10)
+    # mat1._storage._value = torch.randn(mat1.nnz(), 20)
+    # print(mat1.coo()[1].size())
+    # mat1 = mat1.narrow(2, start=10, length=10)
+    # print(mat1.coo()[1].size())

     # mat1 = mat1.t()
     # mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
...
torch_sparse/transpose.py
@@ -29,14 +29,17 @@ def transpose(index, value, m, n, coalesced=True):

 def t(mat):
-    ((row, col), value), perm = mat.coo(), mat._storage.csr_to_csc
+    (row, col), value = mat.coo()
+    csr_to_csc = mat._storage.csr_to_csc
+
     storage = mat._storage.__class__(
-        index=torch.stack([col, row], dim=0)[:, perm],
-        value=value[perm] if mat.has_value() else None,
+        index=torch.stack([col, row], dim=0)[:, csr_to_csc],
+        value=value[csr_to_csc] if mat.has_value() else None,
         sparse_size=mat.sparse_size()[::-1],
         rowptr=mat._storage._colptr,
         colptr=mat._storage._rowptr,
         csr_to_csc=mat._storage._csc_to_csr,
-        csc_to_csr=perm,
+        csc_to_csr=csr_to_csc,
         is_sorted=True)

     return mat.__class__.from_storage(storage)
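`t()` avoids recomputation by reusing the cached `csr_to_csc` permutation: sorting the COO entries by `(col, row)` simultaneously orders the transposed matrix by row and tells you how to permute `value`, while the cached `rowptr` and `colptr` simply swap roles. A standalone sketch with made-up data:

import torch

row = torch.tensor([0, 0, 1, 1])            # a 2x3 matrix, CSR-sorted COO
col = torch.tensor([1, 2, 0, 2])
value = torch.tensor([10., 20., 30., 40.])
num_rows = 2

# Sorting by (col, row) is exactly the CSR -> CSC permutation:
csr_to_csc = (col * num_rows + row).argsort()  # tensor([2, 0, 1, 3])

# Swap row/col and permute: the transpose comes out row-sorted again.
index_t = torch.stack([col, row], dim=0)[:, csr_to_csc]
value_t = value[csr_to_csc]
print(index_t)  # [[0, 1, 2, 2], [1, 0, 0, 1]]
print(value_t)  # tensor([30., 10., 20., 40.])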