Commit 3c6dbfa1 authored by rusty1s's avatar rusty1s
Browse files

reduction

parent 2515ce6d
......@@ -45,3 +45,4 @@ from .masked_select import masked_select, masked_select_nnz
from .diag import set_diag, remove_diag
from .add import add, add_, add_nnz, add_nnz_
from .mul import mul, mul_, mul_nnz, mul_nnz_
from .reduce import sum, mean, min, max
......@@ -9,8 +9,7 @@ from torch_sparse.tensor import SparseTensor
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
......@@ -30,8 +29,7 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
......
......@@ -25,8 +25,7 @@ def index_select(src: SparseTensor, dim: int,
device=col.device).repeat_interleave(rowcount)
perm = torch.arange(row.size(0), device=row.device)
# TODO
# perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
col = col[perm]
......@@ -54,8 +53,7 @@ def index_select(src: SparseTensor, dim: int,
device=row.device).repeat_interleave(colcount)
perm = torch.arange(col.size(0), device=col.device)
# TODO
# perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
row = row[perm]
csc2csr = (idx.size(0) * row + col).argsort()
......
......@@ -9,8 +9,7 @@ from torch_sparse.tensor import SparseTensor
def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
......@@ -30,8 +29,7 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
......
from typing import Optional
import torch
import torch_scatter
from torch_scatter import segment_csr
def reduction(src, dim=None, reduce='sum', deterministic=False):
assert reduce in ['sum', 'add', 'mean', 'min', 'max']
if dim is None and src.has_value():
return getattr(torch, reduce)(src.storage.value)
if dim is None and not src.has_value():
value = src.nnz() if reduce in ['sum', 'add'] else 1
return torch.tensor(value, device=src.device)
dims = [dim] if isinstance(dim, int) else dim
dims = sorted([src.dim() + dim if dim < 0 else dim for dim in dims])
assert dims[-1] < src.dim()
rowptr, col, value = src.csr()
sparse_dims = tuple(set([d for d in dims if d < 2]))
dense_dims = tuple(set([d - 1 for d in dims if d > 1]))
if len(sparse_dims) == 2 and src.has_value():
return getattr(torch, reduce)(value, dim=(0, ) + dense_dims)
if len(sparse_dims) == 2 and not src.has_value():
value = src.nnz() if reduce in ['sum', 'add'] else 1
return torch.tensor(value, device=src.device)
from torch_scatter import scatter, segment_csr
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    """Reduce the non-zero entries of :obj:`src`.

    Args:
        src: The sparse input tensor.
        dim: The dimension to reduce over. :obj:`None` reduces all entries
            to a scalar; :obj:`0` reduces over rows (one value per column),
            :obj:`1` over columns (one value per row), and :obj:`dim > 1`
            reduces a dense value dimension.
        reduce: The reduce operation: :obj:`'sum'`, :obj:`'add'`,
            :obj:`'mean'`, :obj:`'min'` or :obj:`'max'`.

    Raises:
        ValueError: If :obj:`reduce` is unsupported, or if :obj:`dim > 1`
            while :obj:`src` holds no explicit values.
    """
    value = src.storage.value()

    if dim is None:
        # Full reduction to a scalar.
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError
        else:
            # No explicit values: every stored entry counts as `1`.
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(), device=src.device())
            else:
                raise ValueError
    else:
        if dim < 0:  # Support negative dimensions.
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            col = src.storage.col()
            # Reducing over rows yields one entry per *column*, so the
            # output length is `src.size(1)` (not `src.size(0)`), and the
            # requested reduce operation must be forwarded to `scatter`.
            return scatter(value, col, dim=0, dim_size=src.size(1),
                           reduce=reduce)
        elif dim == 0 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.colcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                # All implicit values equal `1`; keep the result on the
                # same device as `src`.
                return torch.ones(src.size(1), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError
        elif dim == 1 and value is not None:
            # CSR layout gives a deterministic segmented reduction per row.
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif dim == 1 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.rowcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(0), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError
        elif dim > 1 and value is not None:
            # Reduce over a dense value dimension; the sparse layout is
            # untouched. `value` dim 0 is the nnz dimension, hence `dim - 1`.
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError
        else:
            # `dim > 1` without explicit values: there is no dense
            # dimension to reduce over.
            raise ValueError
if len(dense_dims) > 0 and len(sparse_dims) > 0:
dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
value = getattr(torch, reduce)(value, dim=dense_dims)
value = value[0] if isinstance(value, tuple) else value
else:
raise ValueError
if sparse_dims[0] == 1 and src.has_value():
out = segment_csr(value, rowptr)
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
if sparse_dims[0] == 1 and not src.has_value():
if reduce in ['sum', 'add']:
return src.storage.rowcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
return torch.ones(src.size(0), device=src.device), None
else:
return torch.ones(src.size(0), device=src.device)
deterministic = src.storage.has_csr2csc() or deterministic
if sparse_dims[0] == 0 and deterministic and src.has_value():
csr2csc = src.storage.csr2csc
out = segment_csr(value[csr2csc], src.storage.colptr)
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
if sparse_dims[0] == 0 and src.has_value():
reduce = 'add' if reduce == 'sum' else reduce
func = getattr(torch_scatter, f'scatter_{reduce}')
out = func(value, col, dim=0, dim_size=src.sparse_size(1))
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
if sparse_dims[0] == 0 and not src.has_value():
if reduce in ['sum', 'add']:
return src.storage.colcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
return torch.ones(src.size(1), device=src.device), None
else:
return torch.ones(src.size(1), device=src.device)
@torch.jit.script
def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Sum the entries of :obj:`src`, optionally along dimension :obj:`dim`."""
    out = reduction(src, dim, reduce='sum')
    return out
def sum(src, dim=None, deterministic=False):
    # Thin wrapper around `reduction` with the reduce mode fixed to 'sum'.
    return reduction(src, dim=dim, reduce='sum', deterministic=deterministic)
@torch.jit.script
def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Average the entries of :obj:`src`, optionally along dimension :obj:`dim`."""
    out = reduction(src, dim, reduce='mean')
    return out
def mean(src, dim=None, deterministic=False):
    # Thin wrapper around `reduction` with the reduce mode fixed to 'mean'.
    return reduction(src, dim=dim, reduce='mean', deterministic=deterministic)
@torch.jit.script
def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Take the minimum of the entries of :obj:`src`, optionally along :obj:`dim`."""
    out = reduction(src, dim, reduce='min')
    return out
def min(src, dim=None, deterministic=False):
    # Thin wrapper around `reduction` with the reduce mode fixed to 'min'.
    return reduction(src, dim=dim, reduce='min', deterministic=deterministic)
@torch.jit.script
def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Take the maximum of the entries of :obj:`src`, optionally along :obj:`dim`."""
    out = reduction(src, dim, reduce='max')
    return out
def max(src, dim=None, deterministic=False):
    # Thin wrapper around `reduction` with the reduce mode fixed to 'max'.
    return reduction(src, dim=dim, reduce='max', deterministic=deterministic)
# Expose the functional reductions as instance methods on `SparseTensor`.
# Plain functions (rather than lambdas) are assigned so the bound methods
# carry useful names when introspected.
def _sum_method(self, dim=None):
    return sum(self, dim)


def _mean_method(self, dim=None):
    return mean(self, dim)


def _min_method(self, dim=None):
    return min(self, dim)


def _max_method(self, dim=None):
    return max(self, dim)


SparseTensor.sum = _sum_method
SparseTensor.mean = _mean_method
SparseTensor.min = _min_method
SparseTensor.max = _max_method
......@@ -304,9 +304,8 @@ class SparseStorage(object):
if colptr is not None:
colcount = colptr[1:] - colptr[1:]
else:
raise NotImplementedError
# colcount = scatter_add(torch.ones_like(self._col), self._col,
# dim_size=self._sparse_sizes[1])
colcount = scatter_add(torch.ones_like(self._col), self._col,
dim_size=self._sparse_sizes[1])
self._colcount = colcount
return colcount
......@@ -355,8 +354,7 @@ class SparseStorage(object):
if value is not None:
ptr = mask.nonzero().flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
raise NotImplementedError
# value = segment_csr(value, ptr, reduce=reduce)
value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return SparseStorage(row=row, rowptr=None, col=col, value=value,
......
......@@ -380,11 +380,6 @@ class SparseTensor(object):
# return matmul(self, other, reduce='sum')
# SparseTensor.reduction = torch_sparse.reduce.reduction
# SparseTensor.sum = torch_sparse.reduce.sum
# SparseTensor.mean = torch_sparse.reduce.mean
# SparseTensor.min = torch_sparse.reduce.min
# SparseTensor.max = torch_sparse.reduce.max
# SparseTensor.matmul = matmul
# Python Bindings #############################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment