reduce.py 3.73 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch
import torch_scatter
from torch_scatter import segment_csr
4
5


rusty1s's avatar
rusty1s committed
6
7
def reduction(src, dim=None, reduce='sum', deterministic=False):
    """Reduce the sparse tensor ``src`` with ``reduce`` over ``dim``.

    Args:
        src: Sparse tensor object. This function uses its ``has_value()``,
            ``csr()``, ``storage``, ``nnz()``, ``dim()``, ``size()``,
            ``sparse_size()``, ``set_value()`` and ``device``.
        dim: ``None`` to reduce over every stored entry, or an ``int`` or a
            sequence of ints. Negative dims count from the end. Dims 0 and 1
            are the two sparse dims. A dim ``d >= 2`` refers to dim ``d - 1``
            of the stored value tensor (its trailing dense dims).
        reduce: One of ``'sum'``, ``'mean'``, ``'min'``, ``'max'``.
        deterministic: For reductions over sparse dim 0, use the cached
            ``csr2csc`` permutation plus ``segment_csr`` instead of a scatter.
            It is forced on whenever ``src.storage`` already has ``csr2csc``.

    Returns:
        One of three kinds of result, depending on the path taken:

        - a tensor;
        - a new sparse tensor, when only dense dims are reduced;
        - for ``'min'``/``'max'``, possibly a tuple. Its first element is the
          reduced result and the remaining elements are arg indices
          (``None`` when ``src`` stores no explicit values).
    """
    assert reduce in ['sum', 'mean', 'min', 'max']

    # --- Reduction over every stored entry. ---------------------------------
    if dim is None and src.has_value():
        return getattr(torch, reduce)(src.storage.value)

    if dim is None and not src.has_value():
        # Without explicit values every stored entry counts as 1.
        value = src.nnz() if reduce == 'sum' else 1
        return torch.tensor(value, device=src.device)

    # Normalize `dim` into a sorted list of non-negative dims.
    dims = [dim] if isinstance(dim, int) else dim
    dims = sorted([src.dim() + d if d < 0 else d for d in dims])
    assert dims[-1] < src.dim()

    rowptr, col, value = src.csr()

    # Sorted so that `sparse_dims[0]` below is well-defined.
    sparse_dims = tuple(sorted(set(d for d in dims if d < 2)))
    # The stored values have one leading dim for the non-zeros, followed by
    # the dense dims. Sparse-tensor dim `d >= 2` is therefore value dim `d - 1`.
    # `dense_dims` always stays a tuple; the argument handed to torch is kept
    # in a separate local.
    dense_dims = tuple(sorted(set(d - 1 for d in dims if d > 1)))

    # --- Both sparse dims: collapse the non-zero dim (and any dense dims). --
    if len(sparse_dims) == 2 and src.has_value():
        value_dims = (0, ) + dense_dims
        # torch.min / torch.max only accept a single int `dim`.
        value_dims = value_dims[0] if len(value_dims) == 1 else value_dims
        return getattr(torch, reduce)(value, dim=value_dims)

    if len(sparse_dims) == 2 and not src.has_value():
        value = src.nnz() if reduce == 'sum' else 1
        return torch.tensor(value, device=src.device)

    # --- Dense dims only: reduce the values and keep the sparsity pattern. --
    if len(dense_dims) > 0 and len(sparse_dims) == 0:  # src.has_value()
        dense_arg = dense_dims[0] if len(dense_dims) == 1 else dense_dims
        value = getattr(torch, reduce)(value, dim=dense_arg)
        if isinstance(value, tuple):  # (values, arg) from torch.min/max
            return (src.set_value(value[0], layout='csr'), ) + value[1:]
        return src.set_value(value, layout='csr')

    # --- Mixed: reduce dense dims first, then one sparse dim below. ---------
    if len(dense_dims) > 0 and len(sparse_dims) > 0:
        dense_arg = dense_dims[0] if len(dense_dims) == 1 else dense_dims
        value = getattr(torch, reduce)(value, dim=dense_arg)
        value = value[0] if isinstance(value, tuple) else value

    # --- Reduce over sparse dim 1: one result per row (CSR segments). -------
    if sparse_dims[0] == 1 and src.has_value():
        out = segment_csr(value, rowptr, reduce=reduce)
        out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
        return out

    if sparse_dims[0] == 1 and not src.has_value():
        if reduce == 'sum':
            return src.storage.rowcount.to(torch.get_default_dtype())
        elif reduce in ['min', 'max']:
            # Return an additional `None` arg(min|max) tensor for consistency.
            return torch.ones(src.size(0), device=src.device), None
        else:
            return torch.ones(src.size(0), device=src.device)

    # --- Reduce over sparse dim 0: one result per column. -------------------
    deterministic = src.storage.has_csr2csc() or deterministic

    if sparse_dims[0] == 0 and deterministic and src.has_value():
        # Permute the values into CSC order so columns become contiguous
        # segments delimited by `colptr`.
        csr2csc = src.storage.csr2csc
        out = segment_csr(value[csr2csc], src.storage.colptr, reduce=reduce)
        out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
        return out

    if sparse_dims[0] == 0 and src.has_value():
        op = 'add' if reduce == 'sum' else reduce
        func = getattr(torch_scatter, f'scatter_{op}')
        out = func(value, col, dim=0, dim_size=src.sparse_size(1))
        out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
        return out

    if sparse_dims[0] == 0 and not src.has_value():
        if reduce == 'sum':
            return src.storage.colcount.to(torch.get_default_dtype())
        elif reduce in ['min', 'max']:
            # Return an additional `None` arg(min|max) tensor for consistency.
            return torch.ones(src.size(1), device=src.device), None
        else:
            return torch.ones(src.size(1), device=src.device)

82
83

def sum(src, dim=None, deterministic=False):
    """Sum-reduce ``src`` over ``dim`` by delegating to :func:`reduction`.

    ``dim=None`` reduces over every stored entry. ``deterministic`` is
    forwarded unchanged.
    """
    mode = 'sum'
    return reduction(src, dim, reduce=mode, deterministic=deterministic)
85
86
87


def mean(src, dim=None, deterministic=False):
    """Mean-reduce ``src`` over ``dim`` by delegating to :func:`reduction`.

    ``dim=None`` reduces over every stored entry. ``deterministic`` is
    forwarded unchanged.
    """
    mode = 'mean'
    return reduction(src, dim, reduce=mode, deterministic=deterministic)
89
90
91


def min(src, dim=None, deterministic=False):
    """Min-reduce ``src`` over ``dim`` by delegating to :func:`reduction`.

    Depending on the path :func:`reduction` takes, the result may be a tuple
    whose first element is the minimum and whose second is the arg index
    tensor (or ``None``).
    """
    mode = 'min'
    return reduction(src, dim, reduce=mode, deterministic=deterministic)
93
94
95


def max(src, dim=None, deterministic=False):
    """Max-reduce ``src`` over ``dim`` by delegating to :func:`reduction`.

    Depending on the path :func:`reduction` takes, the result may be a tuple
    whose first element is the maximum and whose second is the arg index
    tensor (or ``None``).
    """
    mode = 'max'
    return reduction(src, dim, reduce=mode, deterministic=deterministic)