__init__.py 5.68 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from .scatter import scatter
rusty1s's avatar
cleaner  
rusty1s committed
2
from .utils import gen_filled_tensor, gen_output
rusty1s's avatar
rusty1s committed
3
4
5


def scatter_add_(output, index, input, dim=0):
    r""" -> Tensor

    Sums up all values from the tensor :attr:`input` into :attr:`output` at
    the indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. For each value in :attr:`input`, its output index is specified
    by its index in :attr:`input` for dimension != :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension = :attr:`dim`. If
    multiple indices reference the same location, their contributions add.

    If :attr:`input` and :attr:`index` are n-dimensional tensors with size
    :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
    :attr:`dim` = i, then :attr:`output` must be an n-dimensional tensor with
    size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
    values of :attr:`index` must be between `0` and `output.size(dim) - 1`.

    For one-dimensional tensors, the operation computes
    :math:`output_i = output_i + \sum_j input_j`, where sum is over
    :math:`j` such that :math:`index_j = i`.

    Args:
        output (Tensor): The destination tensor (updated in place)
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index

    Returns:
        Tensor: The updated :attr:`output`, as returned by
        ``Tensor.scatter_add_``.

    Example::
        >> input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        >> index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        >> output = torch.zeros(2, 6)
        >> scatter_add_(output, index, input, dim=1)
        0  0  4  3  3  0
        2  4  4  0  0  0
        [torch.FloatTensor of size 2x6]
    """
    # The built-in in-place tensor method already implements additive scatter.
    return output.scatter_add_(dim, index, input)


def scatter_add(index, input, dim=0, size=None, fill_value=0):
    r""" -> Tensor

    Sums up all values from the tensor :attr:`input` at the indices
    specified in the :attr:`index` tensor along a given axis :attr:`dim`.
    The output size at dimension :attr:`dim` is given by :attr:`size` and must
    be at least `index.max(dim) + 1`. If :attr:`size` is not given, a
    minimal sized output tensor is returned. The output tensor is initially
    filled with the specified value at :attr:`fill_value`.

    For one-dimensional tensors, the operation computes
    :math:`output_i = \text{fill\_value} + \sum_j input_j`, where sum is over
    :math:`j` such that :math:`index_j = i`.

    A more detailed explanation is described in :meth:`~scatter_add_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index
        size (int, optional): Output size at dimension :attr:`dim`
        fill_value (int, optional): Initial filling of output tensor

    Returns:
        Tensor: The newly allocated output tensor after the additive scatter.
    """
    # Allocate the output (sized via `size`, pre-filled with `fill_value`),
    # then scatter into it in place.
    output = gen_output(index, input, dim, size, fill_value)
    return scatter_add_(output, index, input, dim)


def scatter_sub_(output, index, input, dim=0):
    """Subtract the values of :attr:`input` from :attr:`output` in place at
    the positions given by :attr:`index` along :attr:`dim`.

    If multiple indices reference the same location, their negated
    contributions add. This is implemented as ``Tensor.scatter_add_`` of the
    negated input.

    Args:
        output (Tensor): The destination tensor (updated in place)
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index

    Returns:
        Tensor: The result of ``output.scatter_add_`` (the updated output).
    """
    negated = input.neg()
    return output.scatter_add_(dim, index, negated)


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_sub_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index
        max_index (optional): Forwarded to ``gen_output`` in the same position
            where :func:`scatter_add` passes its ``size`` argument
        fill_value (int, optional): Forwarded to ``gen_output`` as the
            initial filling of the output tensor

    Returns:
        Whatever :func:`scatter_sub_` returns for the generated output.
    """
    return scatter_sub_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_mul_(output, index, input, dim=0):
    """If multiple indices reference the same location, their
    contributions multiply.

    Delegates to the package's ``scatter`` helper with the ``'mul'``
    operation, passing :attr:`output` as the destination.

    Returns:
        Whatever ``scatter('mul', ...)`` returns.
    """
    op_name = 'mul'
    return scatter(op_name, dim, output, index, input)


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of :func:`scatter_mul_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index
        max_index (optional): Forwarded to ``gen_output`` in the same position
            where :func:`scatter_add` passes its ``size`` argument
        fill_value (int, optional): Forwarded to ``gen_output`` as the
            initial filling of the output tensor (defaults to 1)

    Returns:
        Whatever :func:`scatter_mul_` returns for the generated output.
    """
    return scatter_mul_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_div_(output, index, input, dim=0):
    """If multiple indices reference the same location, their
    contributions divide.

    Delegates to the package's ``scatter`` helper with the ``'div'``
    operation, passing :attr:`output` as the destination.

    Returns:
        Whatever ``scatter('div', ...)`` returns.
    """
    op_name = 'div'
    return scatter(op_name, dim, output, index, input)


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of :func:`scatter_div_`.

    Generates an output tensor via ``gen_output`` and applies
    :func:`scatter_div_` to it.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index
        max_index (optional): Forwarded to ``gen_output`` in the same position
            where :func:`scatter_add` passes its ``size`` argument
        fill_value (int, optional): Forwarded to ``gen_output`` as the
            initial filling of the output tensor (defaults to 1)

    Returns:
        Whatever :func:`scatter_div_` returns for the generated output.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # Return the result like every other out-of-place scatter_* wrapper;
    # without this the function always returned None.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """If multiple indices reference the same location, their
    contributions average.

    A zero-filled tensor shaped like :attr:`output` is handed to the
    ``'mean'`` backend together with :attr:`output` (presumably it receives
    per-location counts — TODO confirm against the ``scatter`` backend).
    Zero entries of that tensor are then replaced by one, and :attr:`output`
    is divided by it in place, so the division never divides by zero.

    Returns:
        Tensor: :attr:`output` after the in-place division.
    """
    counts = gen_filled_tensor(output, output.size(), fill_value=0)
    scatter('mean', dim, output, index, input, counts)
    # Avoid division by zero for entries whose count stayed at zero.
    is_empty = counts == 0
    counts[is_empty] = 1
    output /= counts
    return output


def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_mean_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index
        max_index (optional): Forwarded to ``gen_output`` in the same position
            where :func:`scatter_add` passes its ``size`` argument
        fill_value (int, optional): Forwarded to ``gen_output`` as the
            initial filling of the output tensor

    Returns:
        Tensor: The result of :func:`scatter_mean_` on the generated output.
    """
    return scatter_mean_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_max_(output, index, input, dim=0):
    """If multiple indices reference the same location, the maximal
    contribution gets taken.

    A companion tensor shaped like :attr:`output`, built from :attr:`index`
    by ``gen_filled_tensor`` and filled with ``-1``, is passed to the
    ``'max'`` backend (presumably to record argmax positions — TODO confirm).

    Returns:
        Whatever ``scatter('max', ...)`` returns.
    """
    return scatter('max', dim, output, index, input,
                   gen_filled_tensor(index, output.size(), fill_value=-1))


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_max_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index
        max_index (optional): Forwarded to ``gen_output`` in the same position
            where :func:`scatter_add` passes its ``size`` argument
        fill_value (int, optional): Forwarded to ``gen_output`` as the
            initial filling of the output tensor

    Returns:
        Whatever :func:`scatter_max_` returns for the generated output.
    """
    return scatter_max_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_min_(output, index, input, dim=0):
    """If multiple indices reference the same location, the minimal
    contribution gets taken.

    A companion tensor shaped like :attr:`output`, built from :attr:`index`
    by ``gen_filled_tensor`` and filled with ``-1``, is passed to the
    ``'min'`` backend (presumably to record argmin positions — TODO confirm).

    Returns:
        Whatever ``scatter('min', ...)`` returns.
    """
    return scatter('min', dim, output, index, input,
                   gen_filled_tensor(index, output.size(), fill_value=-1))


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_min_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int, optional): The axis along which to index
        max_index (optional): Forwarded to ``gen_output`` in the same position
            where :func:`scatter_add` passes its ``size`` argument
        fill_value (int, optional): Forwarded to ``gen_output`` as the
            initial filling of the output tensor

    Returns:
        Whatever :func:`scatter_min_` returns for the generated output.
    """
    return scatter_min_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


# Public API: each reduction comes as an in-place variant (trailing
# underscore) and an out-of-place variant.
__all__ = [
    'scatter_add_',
    'scatter_add',
    'scatter_sub_',
    'scatter_sub',
    'scatter_mul_',
    'scatter_mul',
    'scatter_div_',
    'scatter_div',
    'scatter_mean_',
    'scatter_mean',
    'scatter_max_',
    'scatter_max',
    'scatter_min_',
    'scatter_min',
]