sub.py 2.06 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from .add import scatter_add


def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/sub.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Subtracts every value of the :attr:`src` tensor from :attr:`out` at the
    positions given by the :attr:`index` tensor along a given axis
    :attr:`dim`. If several indices point to the same location, their
    **negated contributions add** (`cf.` :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i - \sum_j \mathrm{src}_j

    where :math:`\sum` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. (default: :obj:`0`)

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_sub
        src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out = src.new_zeros((2, 6))
        out = scatter_sub(src, index, out=out)
        print(out)

    .. testoutput::

        tensor([[ 0,  0, -4, -3, -3,  0],
                [-2, -4, -4,  0,  0,  0]])
    """
    # Subtraction is expressed as a scatter-add of the negated source, so all
    # output allocation and index handling is delegated to ``scatter_add``.
    negated_src = src.neg()
    return scatter_add(negated_src, index, dim, out, dim_size, fill_value)