sub.py 2.72 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from .utils import gen_output


def scatter_sub_(output, index, input, dim=0):
    r"""Subtracts all values from the :attr:`input` tensor into :attr:`output`
    at the indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
    **negated contributions add** (`cf.` :meth:`~torch_scatter.scatter_add_`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{output}_i = \mathrm{output}_i - \sum_j \mathrm{input}_j

    where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.

    The :attr:`output` tensor is modified in place and is also the return
    value.

    Args:
        output (Tensor): The destination tensor
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_sub_
        input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        output = torch.zeros(2, 6)
        scatter_sub_(output, index, input, dim=1)
        print(output)

    .. testoutput::

        0  0 -4 -3 -3  0
       -2 -4 -4  0  0  0
       [torch.FloatTensor of size 2x6]
    """
    # Subtraction is expressed as an in-place scatter-add of the negated
    # source, so duplicate indices accumulate their negated values.
    return output.scatter_add_(dim, index, -input)


rusty1s's avatar
rename  
rusty1s committed
47
def scatter_sub(index, input, dim=0, size=None, fill_value=0):
    r"""Subtracts all values from the :attr:`input` tensor at the indices
    specified in the :attr:`index` tensor along a given axis :attr:`dim`
    (`cf.` :meth:`~torch_scatter.scatter_sub_` and
    :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{output}_i = \mathrm{fill\_value} - \sum_j \mathrm{input}_j

    where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index
        size (int, optional): Output size at dimension :attr:`dim`
        fill_value (int, optional): Initial filling of output tensor

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_sub
        input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        output = scatter_sub(index, input, dim=1)
        print(output)

    .. testoutput::

        0  0 -4 -3 -3  0
       -2 -4 -4  0  0  0
       [torch.FloatTensor of size 2x6]
    """
    # The destination tensor comes from the package helper `gen_output`
    # (see .utils); presumably it is pre-filled with `fill_value` and sized
    # by `size` along `dim` -- confirm against the helper.
    output = gen_output(index, input, dim, size, fill_value)
    # Delegate the actual subtraction to the in-place variant.
    return scatter_sub_(output, index, input, dim)