Commit e61e3d45 authored by rusty1s

coalesce

parent 7517c965
 import warnings
 import torch
+import torch_scatter
 from torch_scatter import scatter_add, segment_add
@@ -37,17 +38,9 @@ class SparseStorage(object):
         'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
     ]

-    def __init__(self,
-                 index,
-                 value=None,
-                 sparse_size=None,
-                 rowcount=None,
-                 rowptr=None,
-                 colcount=None,
-                 colptr=None,
-                 csr2csc=None,
-                 csc2csr=None,
-                 is_sorted=False):
+    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
+                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
+                 csc2csr=None, is_sorted=False):
         assert index.dtype == torch.long
         assert index.dim() == 2 and index.size(0) == 2
@@ -185,11 +178,27 @@ class SparseStorage(object):
     def is_coalesced(self):
         idx = self.sparse_size(1) * self.row + self.col
-        mask = idx == torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
-        return not mask.any().item()
+        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
+        return mask.all().item()

-    def coalesce(self):
-        raise NotImplementedError
+    def coalesce(self, reduce='add'):
+        idx = self.sparse_size(1) * self.row + self.col
+        mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]], dim=0)
+
+        if mask.all():  # Already coalesced.
+            return self
+
+        index = self.index[:, mask]
+
+        value = self.value
+        if self.has_value():
+            assert reduce in ['add', 'mean', 'min', 'max']
+            idx = mask.cumsum(0) - 1
+            op = getattr(torch_scatter, f'scatter_{reduce}')
+            value = op(value, idx, dim=0, dim_size=idx[-1].item() + 1)
+            value = value[0] if isinstance(value, tuple) else value
+
+        return self.__class__(index, value, self.sparse_size(), is_sorted=True)

     def cached_keys(self):
         return [
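The hunk above is the core of the commit, so a minimal, self-contained sketch of the same computation outside the class may help. It assumes what the new `is_coalesced` check also verifies: entries arrive sorted in row-major order, so equal linearized keys are adjacent. All tensor names and sizes below are illustrative, not from the commit:

```python
import torch
from torch_scatter import scatter_add

# Sorted COO entries with one duplicate at (1, 2).
row = torch.tensor([0, 1, 1, 2, 2])
col = torch.tensor([1, 2, 2, 2, 3])
value = torch.tensor([1., 2., 3., 4., 5.])
num_cols = 4

# Linearize each (row, col) pair into a single integer key.
idx = num_cols * row + col                                   # [1, 6, 6, 10, 11]

# Keep an entry only where its key strictly exceeds its predecessor's;
# duplicates compare equal to the previous key and are masked out.
mask = idx > torch.cat([idx.new_full((1, ), -1), idx[:-1]])  # [T, T, F, T, T]

# Map every entry (survivor or duplicate) onto the slot of its survivor.
slot = mask.cumsum(0) - 1                                    # [0, 1, 1, 2, 3]

out_row, out_col = row[mask], col[mask]
out_value = scatter_add(value, slot, dim=0, dim_size=int(slot[-1]) + 1)

print(out_row.tolist(), out_col.tolist())  # [0, 1, 2, 2] [1, 2, 2, 3]
print(out_value.tolist())                  # [1.0, 5.0, 4.0, 5.0]
```

For `reduce='min'` or `'max'`, `torch_scatter` returns an `(out, argmax)` tuple, which is why the method unwraps `value[0]`. The hunks below belong to the `SparseTensor` wrapper class and mainly thread the new `reduce` argument through.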
@@ -14,8 +14,8 @@ from torch_sparse.masked_select import masked_select, masked_select_nnz
 class SparseTensor(object):
     def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
-        self.storage = SparseStorage(
-            index, value, sparse_size, is_sorted=is_sorted)
+        self.storage = SparseStorage(index, value, sparse_size,
+                                     is_sorted=is_sorted)

     @classmethod
     def from_storage(self, storage):
@@ -36,8 +36,8 @@ class SparseTensor(object):
     @classmethod
     def from_torch_sparse_coo_tensor(self, mat, is_sorted=False):
-        return SparseTensor(
-            mat._indices(), mat._values(), mat.size()[:2], is_sorted=is_sorted)
+        return SparseTensor(mat._indices(), mat._values(),
+                            mat.size()[:2], is_sorted=is_sorted)

     @classmethod
     def from_scipy(self, mat):
@@ -54,8 +54,8 @@ class SparseTensor(object):
         value = torch.from_numpy(mat.data)
         size = mat.shape

-        storage = SparseStorage(
-            index, value, size, rowptr=rowptr, colptr=colptr, is_sorted=True)
+        storage = SparseStorage(index, value, size, rowptr=rowptr,
+                                colptr=colptr, is_sorted=True)

         return SparseTensor.from_storage(storage)
@@ -105,8 +105,8 @@ class SparseTensor(object):
     def is_coalesced(self):
         return self.storage.is_coalesced()

-    def coalesce(self):
-        return self.from_storage(self.storage.coalesce())
+    def coalesce(self, reduce='add'):
+        return self.from_storage(self.storage.coalesce(reduce))

     def cached_keys(self):
         return self.storage.cached_keys()
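Since the wrapper just forwards to the storage, the new `reduce` argument selects how duplicate values merge. A hedged usage sketch (the index/value tensors and the import path are invented for illustration):

```python
import torch
from torch_sparse import SparseTensor  # import path assumed

index = torch.tensor([[0, 1, 1],
                      [1, 2, 2]])       # duplicate entry at (1, 2)
value = torch.tensor([1., 2., 4.])

mat = SparseTensor(index, value)
summed = mat.coalesce()                 # default 'add': (1, 2) -> 2 + 4 = 6
averaged = mat.coalesce(reduce='mean')  # (1, 2) -> 3.0
largest = mat.coalesce(reduce='max')    # (1, 2) -> 4.0
```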
@@ -192,8 +192,8 @@ class SparseTensor(object):
         return self.from_storage(self.storage.apply(lambda x: x.cpu()))

     def cuda(self, device=None, non_blocking=False, **kwargs):
-        storage = self.storage.apply(lambda x: x.cuda(device, non_blocking, **
-                                                      kwargs))
+        storage = self.storage.apply(
+            lambda x: x.cuda(device, non_blocking, **kwargs))
         return self.from_storage(storage)

     @property
@@ -215,8 +215,8 @@ class SparseTensor(object):
         if dtype == self.dtype:
             return self

-        storage = self.storage.apply_value(lambda x: x.type(
-            dtype, non_blocking, **kwargs))
+        storage = self.storage.apply_value(
+            lambda x: x.type(dtype, non_blocking, **kwargs))
         return self.from_storage(storage)
@@ -285,12 +285,9 @@ class SparseTensor(object):
     def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
         index, value = self.coo()
-        return torch.sparse_coo_tensor(
-            index,
-            value if self.has_value() else torch.ones(
-                self.nnz(), dtype=dtype, device=self.device),
-            self.size(),
-            device=self.device,
-            requires_grad=requires_grad)
+        return torch.sparse_coo_tensor(
+            index, value if self.has_value() else torch.ones(
+                self.nnz(), dtype=dtype, device=self.device), self.size(),
+            device=self.device, requires_grad=requires_grad)

     def to_scipy(self, dtype=None, layout=None):
         assert self.dim() == 2
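Both COO conversion helpers appear in this diff, so a small round-trip sketch may be useful; the dense matrix and the import path are assumptions for illustration:

```python
import torch
from torch_sparse import SparseTensor  # import path assumed

dense = torch.tensor([[0., 1.],
                      [2., 0.]])
coo = dense.to_sparse()  # torch's own COO format, already coalesced

mat = SparseTensor.from_torch_sparse_coo_tensor(coo)
back = mat.to_torch_sparse_coo_tensor()
assert torch.equal(back.to_dense(), dense)  # round-trip preserves the matrix
```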
@@ -392,11 +389,6 @@ SparseTensor.index_select_nnz = index_select_nnz
 SparseTensor.masked_select = masked_select
 SparseTensor.masked_select_nnz = masked_select_nnz

-# def __getitem__(self, idx):
-#     # Convert int and slice to index tensor
-#     # Filter list into edge and sparse slice
-#     raise NotImplementedError

 # def remove_diag(self):
 #     raise NotImplementedError
@@ -503,20 +495,30 @@ if __name__ == '__main__':
     value = torch.randn(data.num_edges, 10)
     mat = SparseTensor(data.edge_index, value)

-    index = torch.tensor([0, 1, 2])
-    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
-    mask[:3] = True

-    print(mat[1].size())
-    print(mat[1, 1].size())
-    print(mat[..., -1].size())
-    print(mat[:10, ..., -1].size())
-    print(mat[:, -1].size())
-    print(mat[1, :, -1].size())
-    print(mat[1:4, 1:4].size())
-    print(mat[index].size())
-    print(mat[index, index].size())
-    print(mat[mask, index].size())
+    index = torch.tensor([
+        [0, 1, 1, 2, 2],
+        [1, 2, 2, 2, 3],
+    ])
+    value = torch.tensor([1, 2, 3, 4, 5])
+    mat = SparseTensor(index, value)
+    print(mat)
+    print(mat.coalesce())

+    # index = torch.tensor([0, 1, 2])
+    # mask = torch.zeros(data.num_nodes, dtype=torch.bool)
+    # mask[:3] = True

+    # print(mat[1].size())
+    # print(mat[1, 1].size())
+    # print(mat[..., -1].size())
+    # print(mat[:10, ..., -1].size())
+    # print(mat[:, -1].size())
+    # print(mat[1, :, -1].size())
+    # print(mat[1:4, 1:4].size())
+    # print(mat[index].size())
+    # print(mat[index, index].size())
+    # print(mat[mask, index].size())
+    # mat[::-1]
+    # mat[::2]
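For reference, the coalesce test above is easy to check by hand: the only duplicated coordinate is (1, 2), carried twice with values 2 and 3, so with the default `reduce='add'` the coalesced tensor should hold index `[[0, 1, 2, 2], [1, 2, 2, 3]]` and value `[1, 5, 4, 5]`.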