Commit e61e3d45 authored by rusty1s

coalesce

parent 7517c965
 import warnings
 import torch
+import torch_scatter
 from torch_scatter import scatter_add, segment_add
@@ -37,17 +38,9 @@ class SparseStorage(object):
         'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
     ]
 
-    def __init__(self,
-                 index,
-                 value=None,
-                 sparse_size=None,
-                 rowcount=None,
-                 rowptr=None,
-                 colcount=None,
-                 colptr=None,
-                 csr2csc=None,
-                 csc2csr=None,
-                 is_sorted=False):
+    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
+                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
+                 csc2csr=None, is_sorted=False):
         assert index.dtype == torch.long
         assert index.dim() == 2 and index.size(0) == 2
@@ -185,11 +178,27 @@ class SparseStorage(object):
     def is_coalesced(self):
         idx = self.sparse_size(1) * self.row + self.col
-        mask = idx == torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
-        return not mask.any().item()
+        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
+        return mask.all().item()
 
-    def coalesce(self):
-        raise NotImplementedError
+    def coalesce(self, reduce='add'):
+        idx = self.sparse_size(1) * self.row + self.col
+        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
+
+        if mask.all():  # Already coalesced
+            return self
+
+        index = self.index[:, mask]
+
+        value = self.value
+        if self.has_value():
+            assert reduce in ['add', 'mean', 'min', 'max']
+            idx = mask.cumsum(0) - 1
+            op = getattr(torch_scatter, f'scatter_{reduce}')
+            value = op(value, idx, dim=0, dim_size=idx[-1].item() + 1)
+            value = value[0] if isinstance(value, tuple) else value
+
+        return self.__class__(index, value, self.sparse_size(), is_sorted=True)
 
     def cached_keys(self):
         return [
...
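A note on the de-duplication trick in the new `coalesce` (illustration only, not part of the diff): the entries are assumed sorted by the linearized index `row * num_cols + col`, so duplicates sit next to each other. Comparing `idx` against its shifted copy marks the first occurrence of every (row, col) pair, `index[:, mask]` keeps those unique coordinates, and `mask.cumsum(0) - 1` maps every original entry to its output slot so that `scatter_<reduce>` can fold duplicates together. A standalone sketch of the same steps on plain tensors:

```python
import torch
from torch_scatter import scatter_add

# Sketch of the masking trick used by coalesce(); the data and the dense
# column count (num_cols = 4) are made up for illustration.
row = torch.tensor([0, 1, 1, 2, 2])
col = torch.tensor([1, 2, 2, 2, 3])
value = torch.tensor([1, 2, 3, 4, 5])
num_cols = 4

idx = num_cols * row + col                            # linearized (row, col) index
prev = torch.cat([idx.new_full((1, ), -1), idx[:-1]])
mask = idx > prev                                     # True at the first entry of each pair

row, col = row[mask], col[mask]                       # unique coordinates
slot = mask.long().cumsum(0) - 1                      # output slot for every original entry
value = scatter_add(value, slot, dim=0, dim_size=int(slot[-1]) + 1)

print(row.tolist(), col.tolist(), value.tolist())
# [0, 1, 2, 2] [1, 2, 2, 3] [1, 5, 4, 5]  -> the two (1, 2) entries were summed
```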
@@ -14,8 +14,8 @@ from torch_sparse.masked_select import masked_select, masked_select_nnz
 class SparseTensor(object):
     def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
-        self.storage = SparseStorage(
-            index, value, sparse_size, is_sorted=is_sorted)
+        self.storage = SparseStorage(index, value, sparse_size,
+                                     is_sorted=is_sorted)
 
     @classmethod
     def from_storage(self, storage):
@@ -36,8 +36,8 @@ class SparseTensor(object):
     @classmethod
     def from_torch_sparse_coo_tensor(self, mat, is_sorted=False):
-        return SparseTensor(
-            mat._indices(), mat._values(), mat.size()[:2], is_sorted=is_sorted)
+        return SparseTensor(mat._indices(), mat._values(),
+                            mat.size()[:2], is_sorted=is_sorted)
 
     @classmethod
     def from_scipy(self, mat):
@@ -54,8 +54,8 @@ class SparseTensor(object):
         value = torch.from_numpy(mat.data)
         size = mat.shape
 
-        storage = SparseStorage(
-            index, value, size, rowptr=rowptr, colptr=colptr, is_sorted=True)
+        storage = SparseStorage(index, value, size, rowptr=rowptr,
+                                colptr=colptr, is_sorted=True)
 
         return SparseTensor.from_storage(storage)
@@ -105,8 +105,8 @@ class SparseTensor(object):
     def is_coalesced(self):
         return self.storage.is_coalesced()
 
-    def coalesce(self):
-        return self.from_storage(self.storage.coalesce())
+    def coalesce(self, reduce='add'):
+        return self.from_storage(self.storage.coalesce(reduce))
 
     def cached_keys(self):
         return self.storage.cached_keys()
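A minimal usage sketch for the new `reduce` argument (the import path and the data here are assumptions for illustration, not taken from the diff):

```python
import torch
from torch_sparse import SparseTensor  # import path assumed for this sketch

# Two duplicated (1, 2) entries; coalesce() can now merge them with a
# chosen reduction instead of always summing.
index = torch.tensor([[0, 1, 1], [1, 2, 2]])
value = torch.tensor([1.0, 2.0, 3.0])

mat = SparseTensor(index, value)
print(mat.coalesce())              # default reduce='add' -> (1, 2) holds 5.0
print(mat.coalesce(reduce='max'))  # keeps the larger duplicate -> 3.0
```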
@@ -192,8 +192,8 @@ class SparseTensor(object):
         return self.from_storage(self.storage.apply(lambda x: x.cpu()))
 
     def cuda(self, device=None, non_blocking=False, **kwargs):
-        storage = self.storage.apply(lambda x: x.cuda(device, non_blocking, **
-                                                      kwargs))
+        storage = self.storage.apply(
+            lambda x: x.cuda(device, non_blocking, **kwargs))
         return self.from_storage(storage)
 
     @property
@@ -215,8 +215,8 @@ class SparseTensor(object):
         if dtype == self.dtype:
             return self
 
-        storage = self.storage.apply_value(lambda x: x.type(
-            dtype, non_blocking, **kwargs))
+        storage = self.storage.apply_value(
+            lambda x: x.type(dtype, non_blocking, **kwargs))
 
         return self.from_storage(storage)
@@ -285,12 +285,9 @@ class SparseTensor(object):
     def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
         index, value = self.coo()
-        return torch.sparse_coo_tensor(
-            index,
-            value if self.has_value() else torch.ones(
-                self.nnz(), dtype=dtype, device=self.device),
-            self.size(),
-            device=self.device,
-            requires_grad=requires_grad)
+        return torch.sparse_coo_tensor(
+            index, value if self.has_value() else torch.ones(
+                self.nnz(), dtype=dtype, device=self.device), self.size(),
+            device=self.device, requires_grad=requires_grad)
 
     def to_scipy(self, dtype=None, layout=None):
         assert self.dim() == 2
@@ -392,11 +389,6 @@ SparseTensor.index_select_nnz = index_select_nnz
 SparseTensor.masked_select = masked_select
 SparseTensor.masked_select_nnz = masked_select_nnz
 
-# def __getitem__(self, idx):
-#     # Convert int and slice to index tensor
-#     # Filter list into edge and sparse slice
-#     raise NotImplementedError
-
 # def remove_diag(self):
 #     raise NotImplementedError
@@ -503,20 +495,30 @@ if __name__ == '__main__':
     value = torch.randn(data.num_edges, 10)
     mat = SparseTensor(data.edge_index, value)
 
-    index = torch.tensor([0, 1, 2])
-    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
-    mask[:3] = True
-
-    print(mat[1].size())
-    print(mat[1, 1].size())
-    print(mat[..., -1].size())
-    print(mat[:10, ..., -1].size())
-    print(mat[:, -1].size())
-    print(mat[1, :, -1].size())
-    print(mat[1:4, 1:4].size())
-    print(mat[index].size())
-    print(mat[index, index].size())
-    print(mat[mask, index].size())
+    index = torch.tensor([
+        [0, 1, 1, 2, 2],
+        [1, 2, 2, 2, 3],
+    ])
+    value = torch.tensor([1, 2, 3, 4, 5])
+
+    mat = SparseTensor(index, value)
+    print(mat)
+    print(mat.coalesce())
+
+    # index = torch.tensor([0, 1, 2])
+    # mask = torch.zeros(data.num_nodes, dtype=torch.bool)
+    # mask[:3] = True
+
+    # print(mat[1].size())
+    # print(mat[1, 1].size())
+    # print(mat[..., -1].size())
+    # print(mat[:10, ..., -1].size())
+    # print(mat[:, -1].size())
+    # print(mat[1, :, -1].size())
+    # print(mat[1:4, 1:4].size())
+    # print(mat[index].size())
+    # print(mat[index, index].size())
+    # print(mat[mask, index].size())
 
     # mat[::-1]
     # mat[::2]
...
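As a sanity check on what the new `__main__` example should produce (not part of the diff), the same duplicates can be folded with PyTorch's built-in COO coalescing; the dense shape `(3, 4)` is assumed only to make the tensor well-formed:

```python
import torch

# Expected result for the example above: the two (1, 2) entries (values 2
# and 3) are summed, all other entries pass through unchanged.
index = torch.tensor([
    [0, 1, 1, 2, 2],
    [1, 2, 2, 2, 3],
])
value = torch.tensor([1, 2, 3, 4, 5])

coo = torch.sparse_coo_tensor(index, value, (3, 4)).coalesce()
print(coo.indices().tolist())  # [[0, 1, 2, 2], [1, 2, 2, 3]]
print(coo.values().tolist())   # [1, 5, 4, 5]
```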