Commit 9a608051 authored by rusty1s's avatar rusty1s
Browse files

reduce op

parent 8b78601b
import torch
import torch_scatter
from torch_scatter import segment_csr
def reduce(src, dim=None, reduce='add', deterministic=False):
if dim is None and src.has_value():
op = getattr(torch, 'sum' if reduce == 'add' else reduce)
return op(src.storage.value)
if dim is None and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
value = src.nnz() if reduce == 'add' else 1
return torch.tensor(value, device=src.device)
dims = [dim] if isinstance(dim, int) else sorted(list(dim))
assert dim[-1] < src.dim()
rowptr, col, value = src.csr()
sparse_dims = tuple(set([d for d in dims if d < 2]))
dense_dims = tuple(set([d - 1 for d in dims if d > 1]))
if len(sparse_dims) == 2 and src.has_value():
op = getattr(torch, 'sum' if reduce == 'add' else reduce)
return op(value, dim=(0, ) + dense_dims)
if len(sparse_dims) == 2 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
value = src.nnz() if reduce == 'add' else 1
return torch.tensor(value, device=src.device)
if len(dense_dims) > 0 and len(sparse_dims) == 0:
op = getattr(torch, 'sum' if reduce == 'add' else reduce)
dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
value = op(value, dim=dense_dims)
# TODO: ARGOUT
return src.set_value(value, layout='csr')
if len(dense_dims) > 0 and len(sparse_dims) > 0:
op = getattr(torch, 'sum' if reduce == 'add' else reduce)
dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
value = op(value, dim=dense_dims)
value = value[0] if isinstance(value, tuple) else value
if sparse_dims[0] == 0:
out = segment_csr(value, rowptr)
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
if sparse_dims[0] == 1 and (src.storage._csr2csc or deterministic):
csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
out = segment_csr(value[csr2csc], colptr)
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
if sparse_dims[0] == 1:
op = getattr(torch_scatter, f'scatter_{reduce}')
out = op(value, col, dim=0, dim_size=src.sparse_size(0))
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment