scatter.py 5.87 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from typing import Optional, Tuple

import torch

5
6
from .utils import broadcast

rusty1s's avatar
rusty1s committed
7
8
9
10

def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum-reduce ``src`` at the positions given by ``index`` along ``dim``.

    ``index`` is first passed through ``broadcast`` against ``src`` so that
    it may have fewer dimensions than ``src``.

    :param src: The source tensor.
    :param index: The indices of elements to scatter.
    :param dim: The axis along which to index. Negative values are allowed.
    :param out: Optional destination tensor. When given, values are
        accumulated into it in place via ``Tensor.scatter_add_`` and it is
        returned.
    :param dim_size: Only used when ``out`` is ``None``: the extent of the
        freshly allocated output along ``dim``. If omitted, the extent is
        ``0`` for an empty ``index`` and ``index.max() + 1`` otherwise.

    :rtype: :class:`Tensor`
    """
    index = broadcast(index, src, dim)

    # Caller-provided destination: accumulate straight into it.
    if out is not None:
        return out.scatter_add_(dim, index, src)

    # Otherwise allocate a zero tensor shaped like ``src`` except along
    # ``dim``, whose extent is taken from ``dim_size`` or inferred.
    shape = list(src.size())
    if dim_size is None:
        shape[dim] = 0 if index.numel() == 0 else int(index.max()) + 1
    else:
        shape[dim] = dim_size
    result = torch.zeros(shape, dtype=src.dtype, device=src.device)
    return result.scatter_add_(dim, index, src)
rusty1s's avatar
rusty1s committed
24
25
26
27
28


def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of :func:`scatter_sum`; every argument is forwarded unchanged.

    :rtype: :class:`Tensor`
    """
    return scatter_sum(src, index, dim,
                       out, dim_size)
rusty1s's avatar
rusty1s committed
30
31


rusty1s's avatar
rusty1s committed
32
33
34
35
36
37
def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Multiply-reduce ``src`` at the positions given by ``index``.

    Thin wrapper around the compiled ``torch.ops.torch_scatter.scatter_mul``
    operator; all arguments are forwarded positionally and unchanged.

    :rtype: :class:`Tensor`
    """
    return torch.ops.torch_scatter.scatter_mul(
        src, index, dim, out, dim_size)


rusty1s's avatar
rusty1s committed
38
39
40
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """Average the values of ``src`` that map to the same slot of ``index``.

    The per-slot sum is computed with :func:`scatter_sum` and then divided
    by the number of ``index`` entries targeting each slot. Slots that
    receive no entries get a count of ``1``, so they keep their summed value
    (zero for a freshly allocated output) instead of dividing by zero.

    Floating-point outputs use true division; integer outputs use floor
    division (rounding toward negative infinity).

    :param src: The source tensor.
    :param index: The indices of elements to scatter. It may have fewer
        dimensions than ``src``.
    :param dim: The axis along which to index. Negative values count from
        the end of ``src``.
    :param out: Optional destination tensor, updated in place. Its
        pre-existing contents are included in the sum but not in the count.
    :param dim_size: Extent of the output along ``dim`` when ``out`` is not
        given.

    :rtype: :class:`Tensor`
    """
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    # Translate ``dim`` (relative to ``src``) into the matching axis of
    # ``index``. If ``index`` has fewer dimensions, count along its last
    # axis instead.
    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    # Number of entries hitting each output slot.
    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count[count < 1] = 1  # avoid division by zero for empty slots
    count = broadcast(count, out, dim)

    if out.is_floating_point():
        out.true_divide_(count)
    else:
        # Explicit rounding mode: ``Tensor.floor_divide_`` is deprecated and
        # rounded toward zero in older PyTorch releases despite its name.
        out.div_(count, rounding_mode='floor')
    return out
rusty1s's avatar
rusty1s committed
60
61


rusty1s's avatar
rusty1s committed
62
63
64
65
def scatter_min(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Min-reduce ``src`` at the positions given by ``index`` along ``dim``.

    Thin wrapper around the compiled ``torch.ops.torch_scatter.scatter_min``
    operator; all arguments are forwarded positionally and unchanged.

    :returns: The operator's pair of tensors (presumably the reduced values
        and their source positions — confirm against the C++ op).
    """
    return torch.ops.torch_scatter.scatter_min(
        src, index, dim, out, dim_size)


rusty1s's avatar
rusty1s committed
69
70
71
72
def scatter_max(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Max-reduce ``src`` at the positions given by ``index`` along ``dim``.

    Thin wrapper around the compiled ``torch.ops.torch_scatter.scatter_max``
    operator; all arguments are forwarded positionally and unchanged.

    :returns: The operator's pair of tensors (presumably the reduced values
        and their source positions — confirm against the C++ op).
    """
    return torch.ops.torch_scatter.scatter_max(
        src, index, dim, out, dim_size)


def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension :attr:`dim`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1`, although no specific ordering of indices is required.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    :param src: The source tensor.
    :param index: The indices of elements to scatter.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor.
    :param dim_size: If :attr:`out` is not given, automatically create output
        with size :attr:`dim_size` at dimension :attr:`dim`.
        If :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned.
    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
        :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). :obj:`"add"` is
        accepted as an alias of :obj:`"sum"`. (default: :obj:`"sum"`)

    :raises ValueError: If :attr:`reduce` is not one of the supported
        operations.

    :rtype: :class:`Tensor`

    .. code-block:: python

        from torch_scatter import scatter

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 1, 0, 1, 2, 1])

        # Broadcasting in the first and last dim.
        out = scatter(src, index, dim=1, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        # The min/max ops return a pair of tensors; only the first
        # (presumably the reduced values) is part of this function's result.
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError(
            f"Invalid 'reduce' argument '{reduce}' (expected one of 'sum', "
            f"'add', 'mul', 'mean', 'min' or 'max')")