# add.py
from .utils import gen_output


def scatter_add_(output, index, input, dim=0):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Sums all values from the :attr:`input` tensor into :attr:`output` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. For each value in :attr:`input`, its output index is specified
    by its index in :attr:`input` for dimensions outside of :attr:`dim` and by
    the corresponding value in :attr:`index` for dimension :attr:`dim`. If
    multiple indices reference the same location, their **contributions add**.

    If :attr:`input` and :attr:`index` are n-dimensional tensors with size
    :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
    :attr:`dim` = `i`, then :attr:`output` must be an n-dimensional tensor with
    size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})` for some
    arbitrary :math:`y`. Moreover, the values of :attr:`index` must be between
    `0` and `output.size(dim) - 1`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{output}_i = \mathrm{output}_i + \sum_j \mathrm{input}_j

    where the sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.

    The operation is performed in-place on :attr:`output`, which is then
    returned.

    Args:
        output (Tensor): The destination tensor (modified in-place)
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index (default: `0`)

    Returns:
        The updated :attr:`output` tensor.

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_add_
        input =     torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        output = torch.zeros(2, 6)
        scatter_add_(output, index, input, dim=1)
        print(output)

    .. testoutput::

        0  0  4  3  3  0
        2  4  4  0  0  0
       [torch.FloatTensor of size 2x6]
    """
    # Delegate to PyTorch's native in-place scatter-add. Note that the native
    # method takes ``dim`` first, unlike this function's signature.
    return output.scatter_add_(dim, index, input)


def scatter_add(index, input, dim=0, size=None, fill_value=0):
    r"""Sums all values from the :attr:`input` tensor at the indices specified
    in the :attr:`index` tensor along a given axis :attr:`dim` (`cf.`
    :meth:`~torch_scatter.scatter_add_`).

    The output size at dimension :attr:`dim` is given by :attr:`size` and must
    be at least `index.max() + 1`. If :attr:`size` is not given, a minimal
    sized output tensor is returned. The output tensor is prefilled with the
    specified value from :attr:`fill_value`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{output}_i = \mathrm{fill\_value} + \sum_j \mathrm{input}_j

    where the sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index (default: `0`)
        size (int, optional): Output size at dimension :attr:`dim`
            (default: `None`)
        fill_value (int, optional): Initial filling of output tensor
            (default: `0`)

    Returns:
        A new output tensor holding the scattered sums.

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_add
        input =     torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        output = scatter_add(index, input, dim=1)
        print(output)

    .. testoutput::

        0  0  4  3  3  0
        2  4  4  0  0  0
       [torch.FloatTensor of size 2x6]
    """
    # Build the destination tensor with the shared ``gen_output`` helper. Per
    # the docstring, it is sized by ``size`` (or minimally) and prefilled with
    # ``fill_value``. Then accumulate into it in-place.
    output = gen_output(index, input, dim, size, fill_value)
    return scatter_add_(output, index, input, dim)