# reduce.py
import torch
import torch_scatter
from torch_scatter import segment_csr
# Shared implementation used by the public reduction functions in this file.


def __reduce__(src, dim=None, reduce='add', deterministic=False):
    """Reduce a sparse tensor over all, or a subset, of its dimensions.

    Args:
        src: Sparse tensor-like object. This function relies on its
            ``has_value()``, ``nnz()``, ``dim()``, ``csr()``,
            ``set_value()``, ``sparse_size()``, ``device`` and ``storage``
            members.
        dim: ``None`` to reduce over everything, a single ``int``, or an
            iterable of ints. Dimensions 0 and 1 are treated as the sparse
            dimensions; a dimension ``d > 1`` is mapped to axis ``d - 1`` of
            the value tensor returned by ``src.csr()``.
        reduce: One of ``'add'``, ``'mean'``, ``'min'`` or ``'max'``.
        deterministic: If true, a reduction that groups entries by column
            uses ``segment_csr`` on the column-sorted values instead of a
            scatter operation.

    Returns:
        A tensor when any sparse dimension is reduced, or the result of
        ``src.set_value(...)`` when only dense dimensions are reduced. When
        the underlying operation returns a tuple (e.g. values and indices
        for ``'min'``/``'max'``), the tuple is passed through unless dense
        dimensions were reduced as well.
    """
    # Full reduction over every dimension.
    if dim is None:
        if src.has_value():
            func = getattr(torch, 'sum' if reduce == 'add' else reduce)
            return func(src.storage.value)
        # Without explicit values every stored entry counts as one.
        assert reduce in ['add', 'mean', 'min', 'max']
        count = src.nnz() if reduce == 'add' else 1
        return torch.tensor(count, device=src.device)

    dims = [dim] if isinstance(dim, int) else sorted(dim)
    # Validate against the normalised list: `dim` itself may be a plain int
    # and is therefore not indexable.
    assert dims[-1] < src.dim()

    rowptr, col, value = src.csr()

    sparse_dims = tuple(sorted({d for d in dims if d < 2}))
    # The two sparse dims collapse into axis 0 of `value`, hence the shift.
    dense_dims = tuple(sorted({d - 1 for d in dims if d > 1}))
    has_dense = len(dense_dims) > 0

    # Both sparse dimensions are reduced: collapse the nnz axis directly.
    if len(sparse_dims) == 2:
        if src.has_value():
            func = getattr(torch, 'sum' if reduce == 'add' else reduce)
            return func(value, dim=(0, ) + dense_dims)
        assert reduce in ['add', 'mean', 'min', 'max']
        count = src.nnz() if reduce == 'add' else 1
        return torch.tensor(count, device=src.device)

    if has_dense:
        func = getattr(torch, 'sum' if reduce == 'add' else reduce)
        # Keep `dense_dims` a tuple; a separate name holds the int-or-tuple
        # form passed to torch so later `has_dense` checks stay valid.
        reduce_over = dense_dims[0] if len(dense_dims) == 1 else dense_dims
        value = func(value, dim=reduce_over)

        if len(sparse_dims) == 0:
            # Only dense dims were reduced: the sparsity pattern is kept.
            if isinstance(value, tuple):
                return (src.set_value(value[0], layout='csr'), ) + value[1:]
            return src.set_value(value, layout='csr')

        # Indices from the dense reduction are dropped before the sparse one.
        value = value[0] if isinstance(value, tuple) else value

    # NOTE(review): `value` is None here when `src` has no values; that case
    # is not handled by the branches below - confirm intended behaviour.

    if sparse_dims[0] == 0:
        # Forward `reduce`; otherwise segment_csr falls back to its default.
        out = segment_csr(value, rowptr, reduce=reduce)
        out = out[0] if has_dense and isinstance(out, tuple) else out
        return out

    # `_csr2csc` is presumably a cached permutation tensor or None, so test
    # it for presence rather than truthiness - TODO confirm against storage.
    if sparse_dims[0] == 1 and (src.storage._csr2csc is not None
                                or deterministic):
        csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
        out = segment_csr(value[csr2csc], colptr, reduce=reduce)
        out = out[0] if has_dense and isinstance(out, tuple) else out
        return out

    if sparse_dims[0] == 1:
        func = getattr(torch_scatter, f'scatter_{reduce}')
        # Entries are grouped by column index, so the output has one slot
        # per column, matching the `colptr`-based branch above.
        out = func(value, col, dim=0, dim_size=src.sparse_size(1))
        out = out[0] if has_dense and isinstance(out, tuple) else out
        return out
# Public entry points: thin wrappers that select the reduction mode.


def sum(src, dim=None, deterministic=False):
    """Call `__reduce__` on `src` with ``reduce='add'``."""
    return __reduce__(src, dim=dim, deterministic=deterministic, reduce='add')


def mean(src, dim=None, deterministic=False):
    """Call `__reduce__` on `src` with ``reduce='mean'``."""
    return __reduce__(src, dim=dim, deterministic=deterministic, reduce='mean')


def min(src, dim=None, deterministic=False):
    """Call `__reduce__` on `src` with ``reduce='min'``."""
    return __reduce__(src, dim=dim, deterministic=deterministic, reduce='min')


def max(src, dim=None, deterministic=False):
    """Call `__reduce__` on `src` with ``reduce='max'``."""
    return __reduce__(src, dim=dim, deterministic=deterministic, reduce='max')