"cacheflow/vscode:/vscode.git/clone" did not exist on "bb59a3e7302ad6892e097eee4040e3f516e9f4ea"
reduce.py 3.31 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
4
import torch
import torch_scatter
rusty1s's avatar
rusty1s committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from torch_scatter import scatter, segment_csr

from torch_sparse.tensor import SparseTensor


@torch.jit.script
def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    """Reduce the entries of a sparse tensor, globally or along one dimension.

    If ``src`` has no stored values, every non-zero entry is treated as 1.
    In that case 'sum'/'add' count the non-zeros, and 'mean'/'min'/'max'
    yield ones.

    Args:
        src: The sparse tensor to reduce.
        dim: Dimension to reduce over.
            ``None`` reduces over all stored entries and returns a scalar.
            ``0`` reduces over rows, giving one result per column
            (length ``src.size(1)``).
            ``1`` reduces over columns, giving one result per row
            (length ``src.size(0)``).
            Values ``> 1`` reduce the stored values along their
            ``dim - 1`` axis.
            Negative values are offset by ``src.dim()``.
        reduce: One of ``'sum'``/``'add'``, ``'mean'``, ``'min'``, ``'max'``.

    Returns:
        The reduced dense tensor.

    Raises:
        ValueError: If ``reduce`` is not one of the names handled by this
            function's own branches, or ``dim`` matches no supported case.
            Invalid ``reduce`` values on the ``scatter``/``segment_csr``
            paths are reported by those helpers instead.
    """
    value = src.storage.value()

    if dim is None:
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError
        else:
            # No stored values: every non-zero entry counts as 1.
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(), device=src.device())
            else:
                raise ValueError

    else:
        if dim < 0:
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            # Reducing over rows yields one entry per column, so the output
            # length is the number of columns. Forward `reduce` so that
            # 'mean'/'min'/'max' are honoured instead of silently summing.
            col = src.storage.col()
            return scatter(value, col, dim=0, dim_size=src.size(1),
                           reduce=reduce)
        elif dim == 0 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.colcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                # Allocate on the same device as `src`, like the scalar
                # branches above do.
                return torch.ones(src.size(1), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError
        elif dim == 1 and value is not None:
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif dim == 1 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.rowcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(0), dtype=src.dtype(),
                                  device=src.device())
            else:
                raise ValueError
        elif dim > 1 and value is not None:
            # Dimensions beyond the two sparse ones map onto the trailing
            # axes of `value`, shifted down by one.
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError
        else:
            raise ValueError
rusty1s's avatar
rusty1s committed
73
74


rusty1s's avatar
rusty1s committed
75
76
77
@torch.jit.script
def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Sum the entries of ``src``.

    Sums over all entries when ``dim`` is ``None``, otherwise along ``dim``.
    Delegates to ``reduction`` with the 'sum' reducer.
    """
    summed = reduction(src, dim, 'sum')
    return summed
rusty1s's avatar
rusty1s committed
78

79

rusty1s's avatar
rusty1s committed
80
81
82
@torch.jit.script
def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Average the entries of ``src``.

    Averages over all entries when ``dim`` is ``None``, otherwise along
    ``dim``. Delegates to ``reduction`` with the 'mean' reducer.
    """
    averaged = reduction(src, dim, 'mean')
    return averaged
83
84


rusty1s's avatar
rusty1s committed
85
86
87
@torch.jit.script
def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Take the minimum of the entries of ``src``.

    Takes it over all entries when ``dim`` is ``None``, otherwise along
    ``dim``. Delegates to ``reduction`` with the 'min' reducer.
    """
    smallest = reduction(src, dim, 'min')
    return smallest
88
89


rusty1s's avatar
rusty1s committed
90
91
92
@torch.jit.script
def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    """Take the maximum of the entries of ``src``.

    Takes it over all entries when ``dim`` is ``None``, otherwise along
    ``dim``. Delegates to ``reduction`` with the 'max' reducer.
    """
    largest = reduction(src, dim, 'max')
    return largest
93
94


rusty1s's avatar
rusty1s committed
95
96
97
98
# Attach the module-level reductions as SparseTensor methods so callers can
# write e.g. ``t.sum(dim=1)``. The lambdas resolve `sum`/`mean`/`min`/`max`
# from this module's globals at call time, not from the builtins.
setattr(SparseTensor, 'sum', lambda self, dim=None: sum(self, dim))
setattr(SparseTensor, 'mean', lambda self, dim=None: mean(self, dim))
setattr(SparseTensor, 'min', lambda self, dim=None: min(self, dim))
setattr(SparseTensor, 'max', lambda self, dim=None: max(self, dim))