reduce.py 3.2 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from torch_scatter import scatter, segment_csr
from torch_sparse.tensor import SparseTensor


def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    r"""Reduce the stored entries of a :class:`SparseTensor`.

    Only stored entries take part in the reduction. When ``src`` has no
    explicit values, every stored entry is treated as an implicit ``1``.

    Args:
        src: The sparse tensor to reduce.
        dim: Dimension to reduce over. ``None`` reduces over all stored
            entries and returns a scalar tensor. ``0`` reduces along rows
            (one result per column, ``src.size(1)`` results). ``1`` reduces
            along columns (one result per row, via ``rowptr``). ``dim > 1``
            reduces over dimension ``dim - 1`` of the value tensor.
            Negative values are offset by ``src.dim()``.
        reduce: One of ``'sum'`` (alias ``'add'``), ``'mean'``, ``'min'``
            or ``'max'``.

    Returns:
        The reduced tensor.

    Raises:
        ValueError: If ``reduce`` is not supported, or if ``dim`` is out of
            range / refers to a dense dimension of a tensor without values.
    """
    if reduce not in ('sum', 'add', 'mean', 'min', 'max'):
        raise ValueError(f"Unsupported reduction '{reduce}' (expected one of "
                         f"'sum', 'add', 'mean', 'min' or 'max')")
    is_sum = reduce == 'sum' or reduce == 'add'

    value = src.storage.value()

    if dim is None:
        if value is not None:
            if is_sum:
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            return value.max()
        # Without explicit values every stored entry is an implicit ``1``:
        # the sum is the number of stored entries, mean/min/max are ``1``.
        if is_sum:
            return torch.tensor(src.nnz(), dtype=src.dtype(),
                                device=src.device())
        return torch.tensor(1, dtype=src.dtype(), device=src.device())

    if dim < 0:
        dim = src.dim() + dim

    if dim == 0:
        if value is not None:
            col = src.storage.col()
            return scatter(value, col, 0, None, src.size(1), reduce)
        if is_sum:
            return src.storage.colcount().to(src.dtype())
        # NOTE(review): returns 1 even for columns without stored entries,
        # whereas the valued path above goes through ``scatter`` -- confirm
        # this asymmetry is intended.
        # ``device=`` was previously missing, so this always landed on CPU.
        return torch.ones(src.size(1), dtype=src.dtype(), device=src.device())

    if dim == 1:
        if value is not None:
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        if is_sum:
            return src.storage.rowcount().to(src.dtype())
        # Same device fix (and same empty-row caveat) as the ``dim == 0``
        # branch.
        return torch.ones(src.size(0), dtype=src.dtype(), device=src.device())

    if dim > 1 and value is not None:
        # Dimension 0 of ``value`` enumerates the stored entries, so sparse
        # dimension ``dim`` maps to dense value dimension ``dim - 1``.
        if is_sum:
            return value.sum(dim=dim - 1)
        elif reduce == 'mean':
            return value.mean(dim=dim - 1)
        elif reduce == 'min':
            return value.min(dim=dim - 1)[0]
        return value.max(dim=dim - 1)[0]

    raise ValueError(f"Invalid reduction dimension (dim={dim}, "
                     f"src.dim()={src.dim()}, "
                     f"has_value={value is not None})")
rusty1s's avatar
rusty1s committed
68
69


rusty1s's avatar
rusty1s committed
70
71
def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Sum the stored entries of ``src`` over ``dim`` (all entries if None).

    NOTE: this module-level name shadows the builtin ``sum`` in this module.
    """
    return reduction(src=src, dim=dim, reduce='sum')
rusty1s's avatar
rusty1s committed
72

73

rusty1s's avatar
rusty1s committed
74
75
def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Average the stored entries of ``src`` over ``dim`` (all if None)."""
    return reduction(src=src, dim=dim, reduce='mean')
76
77


rusty1s's avatar
rusty1s committed
78
79
def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Minimum of the stored entries of ``src`` over ``dim`` (all if None).

    NOTE: this module-level name shadows the builtin ``min`` in this module.
    """
    return reduction(src=src, dim=dim, reduce='min')
80
81


rusty1s's avatar
rusty1s committed
82
83
def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Maximum of the stored entries of ``src`` over ``dim`` (all if None).

    NOTE: this module-level name shadows the builtin ``max`` in this module.
    """
    return reduction(src=src, dim=dim, reduce='max')
84
85


rusty1s's avatar
rusty1s committed
86
87
88
89
# Expose the module-level reductions as SparseTensor methods, e.g.
# ``t.sum(dim=1)``; leaving ``dim`` as None reduces over every stored entry.
SparseTensor.sum = lambda self, dim=None: sum(src=self, dim=dim)
SparseTensor.mean = lambda self, dim=None: mean(src=self, dim=dim)
SparseTensor.min = lambda self, dim=None: min(src=self, dim=dim)
SparseTensor.max = lambda self, dim=None: max(src=self, dim=dim)