Commit e696cfd6 authored by rusty1s's avatar rusty1s
Browse files

reduce op

parent 5dc4080c
......@@ -14,7 +14,7 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
return torch.tensor(value, device=src.device)
dims = [dim] if isinstance(dim, int) else sorted(list(dim))
assert dim[-1] < src.dim()
assert dims[-1] < src.dim()
rowptr, col, value = src.csr()
......@@ -30,7 +30,7 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
value = src.nnz() if reduce == 'add' else 1
return torch.tensor(value, device=src.device)
if len(dense_dims) > 0 and len(sparse_dims) == 0:
if len(dense_dims) > 0 and len(sparse_dims) == 0: # src.has_value()
func = getattr(torch, 'sum' if reduce == 'add' else reduce)
dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
value = func(value, dim=dense_dims)
......@@ -44,23 +44,43 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
value = func(value, dim=dense_dims)
value = value[0] if isinstance(value, tuple) else value
if sparse_dims[0] == 0:
if sparse_dims[0] == 1 and src.has_value():
out = segment_csr(value, rowptr)
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
if sparse_dims[0] == 1 and (src.storage._csr2csc or deterministic):
if sparse_dims[0] == 1 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
if reduce == 'add':
return src.storage.rowcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
return torch.ones(src.size(0), device=src.device), None
else:
return torch.ones(src.size(0), device=src.device)
deterministic = src.storage._csr2csc is not None or deterministic
if sparse_dims[0] == 0 and deterministic and src.has_value():
csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
out = segment_csr(value[csr2csc], colptr)
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
if sparse_dims[0] == 1:
if sparse_dims[0] == 0 and src.has_value():
func = getattr(torch_scatter, f'scatter_{reduce}')
out = func(value, col, dim=0, dim_size=src.sparse_size(0))
out = func(value, col, dim=0, dim_size=src.sparse_size(1))
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
if sparse_dims[0] == 0 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
if reduce == 'add':
return src.storage.colcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
return torch.ones(src.size(1), device=src.device), None
else:
return torch.ones(src.size(1), device=src.device)
def sum(src, dim=None, deterministic=False):
return __reduce__(src, dim, reduce='add', deterministic=deterministic)
......
......@@ -10,13 +10,14 @@ from torch_sparse.narrow import narrow
from torch_sparse.select import select
from torch_sparse.index_select import index_select, index_select_nnz
from torch_sparse.masked_select import masked_select, masked_select_nnz
import torch_sparse.reduce
from torch_sparse.add import add, add_nnz
class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
self.storage = SparseStorage(
index, value, sparse_size, is_sorted=is_sorted)
self.storage = SparseStorage(index, value, sparse_size,
is_sorted=is_sorted)
@classmethod
def from_storage(self, storage):
......@@ -37,8 +38,8 @@ class SparseTensor(object):
@classmethod
def from_torch_sparse_coo_tensor(self, mat, is_sorted=False):
return SparseTensor(
mat._indices(), mat._values(), mat.size()[:2], is_sorted=is_sorted)
return SparseTensor(mat._indices(), mat._values(),
mat.size()[:2], is_sorted=is_sorted)
@classmethod
def from_scipy(self, mat):
......@@ -55,8 +56,8 @@ class SparseTensor(object):
value = torch.from_numpy(mat.data)
size = mat.shape
storage = SparseStorage(
index, value, size, rowptr=rowptr, colptr=colptr, is_sorted=True)
storage = SparseStorage(index, value, size, rowptr=rowptr,
colptr=colptr, is_sorted=True)
return SparseTensor.from_storage(storage)
......@@ -193,8 +194,8 @@ class SparseTensor(object):
return self.from_storage(self.storage.apply(lambda x: x.cpu()))
def cuda(self, device=None, non_blocking=False, **kwargs):
storage = self.storage.apply(lambda x: x.cuda(device, non_blocking, **
kwargs))
storage = self.storage.apply(
lambda x: x.cuda(device, non_blocking, **kwargs))
return self.from_storage(storage)
@property
......@@ -216,8 +217,8 @@ class SparseTensor(object):
if dtype == self.dtype:
return self
storage = self.storage.apply_value(lambda x: x.type(
dtype, non_blocking, **kwargs))
storage = self.storage.apply_value(
lambda x: x.type(dtype, non_blocking, **kwargs))
return self.from_storage(storage)
......@@ -286,12 +287,9 @@ class SparseTensor(object):
def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
index, value = self.coo()
return torch.sparse_coo_tensor(
index,
value if self.has_value() else torch.ones(
self.nnz(), dtype=dtype, device=self.device),
self.size(),
device=self.device,
requires_grad=requires_grad)
index, value if self.has_value() else torch.ones(
self.nnz(), dtype=dtype, device=self.device), self.size(),
device=self.device, requires_grad=requires_grad)
def to_scipy(self, dtype=None, layout=None):
assert self.dim() == 2
......@@ -392,6 +390,10 @@ SparseTensor.index_select = index_select
SparseTensor.index_select_nnz = index_select_nnz
SparseTensor.masked_select = masked_select
SparseTensor.masked_select_nnz = masked_select_nnz
SparseTensor.sum = torch_sparse.reduce.sum
SparseTensor.mean = torch_sparse.reduce.mean
SparseTensor.min = torch_sparse.reduce.min
SparseTensor.max = torch_sparse.reduce.max
SparseTensor.add = add
SparseTensor.add_nnz = add_nnz
......@@ -461,30 +463,38 @@ if __name__ == '__main__':
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
# dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/PubMed', 'PubMed')
dataset = Reddit('/tmp/Reddit')
# dataset = Planetoid('/tmp/PubMed', 'PubMed')
data = dataset[0].to(device)
# value = torch.randn(data.num_edges, 10)
mat = SparseTensor(data.edge_index)
perm = torch.arange(data.num_nodes)
perm = torch.randperm(data.num_nodes)
value = torch.randn((data.num_edges, 8), device=device)
mat = SparseTensor(data.edge_index, value)
print(mat)
mat1 = SparseTensor(torch.tensor([[0, 1], [0, 1]]))
mat2 = SparseTensor(torch.tensor([[0, 0, 1], [0, 1, 0]]))
add(mat1, mat2)
# print(mat2)
raise NotImplementedError
for _ in range(10):
x = torch.randn(1000, 1000, device=device).sum()
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
mat[perm]
torch.cuda.synchronize()
out = mat.sum(dim=1)
torch.cuda.synchronize()
print(time.perf_counter() - t)
print(out.size())
# perm = torch.arange(data.num_nodes)
# perm = torch.randperm(data.num_nodes)
# mat1 = SparseTensor(torch.tensor([[0, 1], [0, 1]]))
# mat2 = SparseTensor(torch.tensor([[0, 0, 1], [0, 1, 0]]))
# add(mat1, mat2)
# # print(mat2)
# raise NotImplementedError
# for _ in range(10):
# x = torch.randn(1000, 1000, device=device).sum()
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# mat[perm]
# torch.cuda.synchronize()
# print(time.perf_counter() - t)
# index = torch.tensor([
# [0, 1, 1, 2, 2],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment