__init__.py 5.45 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from .scatter import scatter
rusty1s's avatar
cleaner  
rusty1s committed
2
from .utils import gen_filled_tensor, gen_output
rusty1s's avatar
rusty1s committed
3
4
5


def scatter_add_(output, index, input, dim=0):
    r"""Sums up all values from the tensor :attr:`input` into :attr:`output` at
    the indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. For each value in :attr:`input`, its output index is specified
    by its index in :attr:`input` for dimension != :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension = :attr:`dim`. If
    multiple indices reference the same location, their contributions add.

    If :attr:`input` and :attr:`index` are n-dimensional tensors with equal
    size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
    :attr:`dim` = i, then :attr:`output` must be an n-dimensional tensor with
    size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
    values of :attr:`index` must be between `0` and `output.size(dim) - 1`.

    For one-dimensional tensors, the operation computes
    :math:`output_i = output_i + \sum_j input_j`, where the sum is over
    :math:`j` such that :math:`index_j = i`.

    Args:
        output (Tensor): The destination tensor (updated in place)
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index

    Returns:
        Tensor: :attr:`output`, after the in-place update.

    Example::

        >> input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        >> index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        >> output = torch.zeros(2, 6)
        >> scatter_add_(output, index, input, dim=1)
        0  0  4  3  3  0
        2  4  4  0  0  0
        [torch.FloatTensor of size 2x6]
    """
    # Delegates to PyTorch's in-place Tensor.scatter_add_, which returns the
    # tensor it was called on (i.e. ``output``).
    return output.scatter_add_(dim, index, input)


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    """Sums up all values from the tensor :attr:`input` at the indices
    specified in the :attr:`index` tensor along a given axis :attr:`dim`.
    The output size at dimension :attr:`dim` is given by :attr:`max_index`
    and must be at least ``index.max() + 1`` so that every index is valid.
    If :attr:`max_index` is not given, a minimal sized output tensor is
    returned. The output tensor is filled with :attr:`fill_value` before
    scatter.

    A more detailed explanation of the operation is described in
    :meth:`~scatter_add_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index
        max_index (int): Output size at dimension :attr:`dim`
        fill_value (int): Fill value of output before scatter

    Returns:
        Tensor: The newly created output tensor holding the sums.
    """
    # Allocate a fresh output tensor, then reuse the in-place variant.
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_add_(output, index, input, dim)


def scatter_sub_(output, index, input, dim=0):
    """Subtracts all values from :attr:`input` from :attr:`output` at the
    indices given by :attr:`index` along axis :attr:`dim`, in place.

    If multiple indices reference the same location, their negated
    contributions add.

    Returns:
        Tensor: :attr:`output`, as returned by ``Tensor.scatter_add_``.
    """
    # Subtraction is implemented as a scatter-add of the negated source.
    negated = -input
    return output.scatter_add_(dim, index, negated)


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :meth:`scatter_sub_`.

    Creates a new output tensor with ``gen_output`` (sized by
    :attr:`max_index` along :attr:`dim`, initialised with
    :attr:`fill_value`) and subtracts :attr:`input` into it.
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_sub_(out, index, input, dim=dim)


def scatter_mul_(output, index, input, dim=0):
    """Scatters :attr:`input` into :attr:`output` at the indices in
    :attr:`index` along axis :attr:`dim` using the ``'mul'`` kernel.

    If multiple indices reference the same location, their contributions
    multiply.

    Returns:
        Whatever ``scatter('mul', ...)`` returns.
    """
    op_name = 'mul'
    return scatter(op_name, dim, output, index, input)


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of :meth:`scatter_mul_`.

    Creates a new output tensor with ``gen_output`` (initialised with
    :attr:`fill_value`, which defaults to the multiplicative identity 1)
    and multiplies :attr:`input` into it.
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mul_(out, index, input, dim=dim)


def scatter_div_(output, index, input, dim=0):
    """Scatters :attr:`input` into :attr:`output` at the indices in
    :attr:`index` along axis :attr:`dim` using the ``'div'`` kernel.

    If multiple indices reference the same location, their contributions
    divide.

    Returns:
        Whatever ``scatter('div', ...)`` returns.
    """
    op_name = 'div'
    return scatter(op_name, dim, output, index, input)


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of :meth:`scatter_div_`.

    Creates a new output tensor with ``gen_output`` (sized by
    :attr:`max_index` along :attr:`dim`, initialised with
    :attr:`fill_value`) and divides :attr:`input` into it.

    Returns:
        The result of :meth:`scatter_div_` on the new output tensor.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # Return the result like every other out-of-place scatter_* wrapper;
    # without ``return`` this function always yielded None.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Scatters :attr:`input` into :attr:`output` at the indices in
    :attr:`index` along axis :attr:`dim` using the ``'mean'`` kernel, then
    divides :attr:`output` in place by a per-location count tensor.

    If multiple indices reference the same location, their contributions
    average.

    Returns:
        Tensor: :attr:`output`, after the in-place division.
    """
    # Zero-filled tensor shaped like ``output``; presumably ``scatter('mean')``
    # accumulates per-location contribution counts into it (TODO confirm).
    counts = gen_filled_tensor(output, output.size(), fill_value=0)
    scatter('mean', dim, output, index, input, counts)
    # Locations whose count is still 0 are divided by 1 instead, avoiding a
    # division by zero.
    untouched = counts == 0
    counts[untouched] = 1
    output /= counts
    return output


def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :meth:`scatter_mean_`.

    Creates a new output tensor with ``gen_output`` (sized by
    :attr:`max_index` along :attr:`dim`, initialised with
    :attr:`fill_value`) and averages :attr:`input` into it.
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mean_(out, index, input, dim=dim)


def scatter_max_(output, index, input, dim=0):
    """Scatters :attr:`input` into :attr:`output` at the indices in
    :attr:`index` along axis :attr:`dim` using the ``'max'`` kernel.

    If multiple indices reference the same location, the maximal
    contribution gets taken.

    A companion tensor shaped like :attr:`output`, created from
    :attr:`index` via ``gen_filled_tensor`` and filled with -1, is passed to
    ``scatter`` as well. It presumably receives the arg-max positions, with
    -1 where nothing was scattered (TODO confirm against ``scatter``).

    Returns:
        Whatever ``scatter('max', ...)`` returns.
    """
    return scatter('max', dim, output, index, input,
                   gen_filled_tensor(index, output.size(), fill_value=-1))


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :meth:`scatter_max_`.

    Creates a new output tensor with ``gen_output`` (sized by
    :attr:`max_index` along :attr:`dim`, initialised with
    :attr:`fill_value`) and takes the maximum of :attr:`input` into it.
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_max_(out, index, input, dim=dim)


def scatter_min_(output, index, input, dim=0):
    """Scatters :attr:`input` into :attr:`output` at the indices in
    :attr:`index` along axis :attr:`dim` using the ``'min'`` kernel.

    If multiple indices reference the same location, the minimal
    contribution gets taken.

    A companion tensor shaped like :attr:`output`, created from
    :attr:`index` via ``gen_filled_tensor`` and filled with -1, is passed to
    ``scatter`` as well. It presumably receives the arg-min positions, with
    -1 where nothing was scattered (TODO confirm against ``scatter``).

    Returns:
        Whatever ``scatter('min', ...)`` returns.
    """
    return scatter('min', dim, output, index, input,
                   gen_filled_tensor(index, output.size(), fill_value=-1))


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :meth:`scatter_min_`.

    Creates a new output tensor with ``gen_output`` (sized by
    :attr:`max_index` along :attr:`dim`, initialised with
    :attr:`fill_value`) and takes the minimum of :attr:`input` into it.
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_min_(out, index, input, dim=dim)


# Public API: every reduction comes as an in-place variant (trailing
# underscore) followed by its out-of-place counterpart.
__all__ = [
    'scatter_add_',
    'scatter_add',
    'scatter_sub_',
    'scatter_sub',
    'scatter_mul_',
    'scatter_mul',
    'scatter_div_',
    'scatter_div',
    'scatter_mean_',
    'scatter_mean',
    'scatter_max_',
    'scatter_max',
    'scatter_min_',
    'scatter_min',
]