"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "9ccebe5b50823e638322e84c65a3526ea5d684b5"
Commit e696cfd6 authored by rusty1s's avatar rusty1s
Browse files

reduce op

parent 5dc4080c
...@@ -14,7 +14,7 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False): ...@@ -14,7 +14,7 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
return torch.tensor(value, device=src.device) return torch.tensor(value, device=src.device)
dims = [dim] if isinstance(dim, int) else sorted(list(dim)) dims = [dim] if isinstance(dim, int) else sorted(list(dim))
assert dim[-1] < src.dim() assert dims[-1] < src.dim()
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
...@@ -30,7 +30,7 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False): ...@@ -30,7 +30,7 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
value = src.nnz() if reduce == 'add' else 1 value = src.nnz() if reduce == 'add' else 1
return torch.tensor(value, device=src.device) return torch.tensor(value, device=src.device)
if len(dense_dims) > 0 and len(sparse_dims) == 0: if len(dense_dims) > 0 and len(sparse_dims) == 0: # src.has_value()
func = getattr(torch, 'sum' if reduce == 'add' else reduce) func = getattr(torch, 'sum' if reduce == 'add' else reduce)
dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
value = func(value, dim=dense_dims) value = func(value, dim=dense_dims)
...@@ -44,23 +44,43 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False): ...@@ -44,23 +44,43 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
value = func(value, dim=dense_dims) value = func(value, dim=dense_dims)
value = value[0] if isinstance(value, tuple) else value value = value[0] if isinstance(value, tuple) else value
if sparse_dims[0] == 0: if sparse_dims[0] == 1 and src.has_value():
out = segment_csr(value, rowptr) out = segment_csr(value, rowptr)
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out return out
if sparse_dims[0] == 1 and (src.storage._csr2csc or deterministic): if sparse_dims[0] == 1 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
if reduce == 'add':
return src.storage.rowcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
return torch.ones(src.size(0), device=src.device), None
else:
return torch.ones(src.size(0), device=src.device)
deterministic = src.storage._csr2csc is not None or deterministic
if sparse_dims[0] == 0 and deterministic and src.has_value():
csr2csc, colptr = src.storage.csr2csc, src.storage.colptr csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
out = segment_csr(value[csr2csc], colptr) out = segment_csr(value[csr2csc], colptr)
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out return out
if sparse_dims[0] == 1: if sparse_dims[0] == 0 and src.has_value():
func = getattr(torch_scatter, f'scatter_{reduce}') func = getattr(torch_scatter, f'scatter_{reduce}')
out = func(value, col, dim=0, dim_size=src.sparse_size(0)) out = func(value, col, dim=0, dim_size=src.sparse_size(1))
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out return out
if sparse_dims[0] == 0 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
if reduce == 'add':
return src.storage.colcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
return torch.ones(src.size(1), device=src.device), None
else:
return torch.ones(src.size(1), device=src.device)
def sum(src, dim=None, deterministic=False): def sum(src, dim=None, deterministic=False):
return __reduce__(src, dim, reduce='add', deterministic=deterministic) return __reduce__(src, dim, reduce='add', deterministic=deterministic)
......
...@@ -10,13 +10,14 @@ from torch_sparse.narrow import narrow ...@@ -10,13 +10,14 @@ from torch_sparse.narrow import narrow
from torch_sparse.select import select from torch_sparse.select import select
from torch_sparse.index_select import index_select, index_select_nnz from torch_sparse.index_select import index_select, index_select_nnz
from torch_sparse.masked_select import masked_select, masked_select_nnz from torch_sparse.masked_select import masked_select, masked_select_nnz
import torch_sparse.reduce
from torch_sparse.add import add, add_nnz from torch_sparse.add import add, add_nnz
class SparseTensor(object): class SparseTensor(object):
def __init__(self, index, value=None, sparse_size=None, is_sorted=False): def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
self.storage = SparseStorage( self.storage = SparseStorage(index, value, sparse_size,
index, value, sparse_size, is_sorted=is_sorted) is_sorted=is_sorted)
@classmethod @classmethod
def from_storage(self, storage): def from_storage(self, storage):
...@@ -37,8 +38,8 @@ class SparseTensor(object): ...@@ -37,8 +38,8 @@ class SparseTensor(object):
@classmethod @classmethod
def from_torch_sparse_coo_tensor(self, mat, is_sorted=False): def from_torch_sparse_coo_tensor(self, mat, is_sorted=False):
return SparseTensor( return SparseTensor(mat._indices(), mat._values(),
mat._indices(), mat._values(), mat.size()[:2], is_sorted=is_sorted) mat.size()[:2], is_sorted=is_sorted)
@classmethod @classmethod
def from_scipy(self, mat): def from_scipy(self, mat):
...@@ -55,8 +56,8 @@ class SparseTensor(object): ...@@ -55,8 +56,8 @@ class SparseTensor(object):
value = torch.from_numpy(mat.data) value = torch.from_numpy(mat.data)
size = mat.shape size = mat.shape
storage = SparseStorage( storage = SparseStorage(index, value, size, rowptr=rowptr,
index, value, size, rowptr=rowptr, colptr=colptr, is_sorted=True) colptr=colptr, is_sorted=True)
return SparseTensor.from_storage(storage) return SparseTensor.from_storage(storage)
...@@ -193,8 +194,8 @@ class SparseTensor(object): ...@@ -193,8 +194,8 @@ class SparseTensor(object):
return self.from_storage(self.storage.apply(lambda x: x.cpu())) return self.from_storage(self.storage.apply(lambda x: x.cpu()))
def cuda(self, device=None, non_blocking=False, **kwargs): def cuda(self, device=None, non_blocking=False, **kwargs):
storage = self.storage.apply(lambda x: x.cuda(device, non_blocking, ** storage = self.storage.apply(
kwargs)) lambda x: x.cuda(device, non_blocking, **kwargs))
return self.from_storage(storage) return self.from_storage(storage)
@property @property
...@@ -216,8 +217,8 @@ class SparseTensor(object): ...@@ -216,8 +217,8 @@ class SparseTensor(object):
if dtype == self.dtype: if dtype == self.dtype:
return self return self
storage = self.storage.apply_value(lambda x: x.type( storage = self.storage.apply_value(
dtype, non_blocking, **kwargs)) lambda x: x.type(dtype, non_blocking, **kwargs))
return self.from_storage(storage) return self.from_storage(storage)
...@@ -286,12 +287,9 @@ class SparseTensor(object): ...@@ -286,12 +287,9 @@ class SparseTensor(object):
def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False): def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
index, value = self.coo() index, value = self.coo()
return torch.sparse_coo_tensor( return torch.sparse_coo_tensor(
index, index, value if self.has_value() else torch.ones(
value if self.has_value() else torch.ones( self.nnz(), dtype=dtype, device=self.device), self.size(),
self.nnz(), dtype=dtype, device=self.device), device=self.device, requires_grad=requires_grad)
self.size(),
device=self.device,
requires_grad=requires_grad)
def to_scipy(self, dtype=None, layout=None): def to_scipy(self, dtype=None, layout=None):
assert self.dim() == 2 assert self.dim() == 2
...@@ -392,6 +390,10 @@ SparseTensor.index_select = index_select ...@@ -392,6 +390,10 @@ SparseTensor.index_select = index_select
SparseTensor.index_select_nnz = index_select_nnz SparseTensor.index_select_nnz = index_select_nnz
SparseTensor.masked_select = masked_select SparseTensor.masked_select = masked_select
SparseTensor.masked_select_nnz = masked_select_nnz SparseTensor.masked_select_nnz = masked_select_nnz
SparseTensor.sum = torch_sparse.reduce.sum
SparseTensor.mean = torch_sparse.reduce.mean
SparseTensor.min = torch_sparse.reduce.min
SparseTensor.max = torch_sparse.reduce.max
SparseTensor.add = add SparseTensor.add = add
SparseTensor.add_nnz = add_nnz SparseTensor.add_nnz = add_nnz
...@@ -461,30 +463,38 @@ if __name__ == '__main__': ...@@ -461,30 +463,38 @@ if __name__ == '__main__':
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu' # device = 'cpu'
# dataset = Reddit('/tmp/Reddit') dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/PubMed', 'PubMed') # dataset = Planetoid('/tmp/PubMed', 'PubMed')
data = dataset[0].to(device) data = dataset[0].to(device)
# value = torch.randn(data.num_edges, 10) value = torch.randn((data.num_edges, 8), device=device)
mat = SparseTensor(data.edge_index) mat = SparseTensor(data.edge_index, value)
perm = torch.arange(data.num_nodes) print(mat)
perm = torch.randperm(data.num_nodes)
mat1 = SparseTensor(torch.tensor([[0, 1], [0, 1]]))
mat2 = SparseTensor(torch.tensor([[0, 0, 1], [0, 1, 0]]))
add(mat1, mat2)
# print(mat2)
raise NotImplementedError
for _ in range(10):
x = torch.randn(1000, 1000, device=device).sum()
torch.cuda.synchronize()
t = time.perf_counter() t = time.perf_counter()
for _ in range(100): torch.cuda.synchronize()
mat[perm] out = mat.sum(dim=1)
torch.cuda.synchronize() torch.cuda.synchronize()
print(time.perf_counter() - t) print(time.perf_counter() - t)
print(out.size())
# perm = torch.arange(data.num_nodes)
# perm = torch.randperm(data.num_nodes)
# mat1 = SparseTensor(torch.tensor([[0, 1], [0, 1]]))
# mat2 = SparseTensor(torch.tensor([[0, 0, 1], [0, 1, 0]]))
# add(mat1, mat2)
# # print(mat2)
# raise NotImplementedError
# for _ in range(10):
# x = torch.randn(1000, 1000, device=device).sum()
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# mat[perm]
# torch.cuda.synchronize()
# print(time.perf_counter() - t)
# index = torch.tensor([ # index = torch.tensor([
# [0, 1, 1, 2, 2], # [0, 1, 1, 2, 2],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment