__init__.py 4.68 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from .scatter import scatter
rusty1s's avatar
cleaner  
rusty1s committed
2
from .utils import gen_filled_tensor, gen_output
rusty1s's avatar
rusty1s committed
3
4
5


def scatter_add_(output, index, input, dim=0):
    r"""Sums up all values from the tensor :attr:`input` into :attr:`output`
    at the indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. For each value in :attr:`input`, its output index is
    specified by its index in :attr:`input` for dimension != :attr:`dim` and
    by the corresponding value in :attr:`index` for dimension = :attr:`dim`.
    If multiple indices reference the same location, their contributions add.

    If :attr:`input` and :attr:`index` are n-dimensional tensors with equal
    size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
    :attr:`dim` = i, then :attr:`output` must be an n-dimensional tensor with
    size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
    values of :attr:`index` must be between `0` and `output.size(dim) - 1`.

    For one-dimensional tensors, the operation computes
    :math:`output_i = output_i + \sum_j input_j`, where the sum is over
    :math:`j` such that :math:`index_j = i`.

    The update is performed in place on :attr:`output` via
    :meth:`torch.Tensor.scatter_add_`.

    Args:
        output (Tensor): The destination tensor, modified in place.
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)

    Returns:
        Tensor: :attr:`output`, after the in-place update.

    Example::

        >>> input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        >>> index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        >>> output = torch.zeros(2, 6)
        >>> scatter_add_(output, index, input, dim=1)
         0  0  4  3  3  0
         2  4  4  0  0  0
        [torch.FloatTensor of size 2x6]
    """
    # PyTorch's native method takes (dim, index, src), which is a different
    # argument order from this function's (output, index, input, dim).
    return output.scatter_add_(dim, index, input)
rusty1s's avatar
rusty1s committed
39
40
41
42
43
44
45
46


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_add_`.

    A fresh destination tensor is built with ``gen_output`` from
    :attr:`index`, :attr:`input`, :attr:`dim`, :attr:`max_index` and
    :attr:`fill_value` (presumably sized so that every value of
    :attr:`index` is addressable along :attr:`dim`, and filled with
    :attr:`fill_value` — TODO confirm against ``gen_output``). Then
    :attr:`input` is summed into it.

    Args:
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)
        max_index (int, optional): Forwarded to ``gen_output``.
        fill_value (optional): Initial value of the new output.
            (default: ``0``)

    Returns:
        Tensor: The newly created output tensor holding the sums.
    """
    destination = gen_output(index, input, dim, max_index, fill_value)
    result = scatter_add_(destination, index, input, dim=dim)
    return result


def scatter_sub_(output, index, input, dim=0):
    """Subtracts all values from the tensor :attr:`input` from
    :attr:`output` at the indices specified in the :attr:`index` tensor
    along a given axis :attr:`dim`, in place. Indexing follows the same
    rules as :func:`scatter_add_`. If multiple indices reference the same
    location, their negated contributions add.

    Args:
        output (Tensor): The destination tensor, modified in place.
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)

    Returns:
        Tensor: :attr:`output`, after the in-place update.
    """
    # Subtraction is expressed as a scatter-add of the negated source, so no
    # dedicated kernel is needed (``-input`` allocates a temporary tensor).
    return output.scatter_add_(dim, index, -input)
rusty1s's avatar
rusty1s committed
50
51
52
53
54
55
56
57


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_sub_`.

    Creates a new destination tensor through ``gen_output`` (initialised
    from :attr:`fill_value`; its exact sizing rule lives in ``gen_output`` —
    TODO confirm). It then subtracts every value of :attr:`input` from the
    location selected by :attr:`index` along :attr:`dim`.

    Args:
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)
        max_index (int, optional): Forwarded to ``gen_output``.
        fill_value (optional): Initial value of the new output.
            (default: ``0``)

    Returns:
        Tensor: The newly created output tensor.
    """
    return scatter_sub_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim=dim,
    )


def scatter_mul_(output, index, input, dim=0):
    """Multiplies the values of :attr:`input` into :attr:`output` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. Indexing follows the same rules as :func:`scatter_add_`.
    If multiple indices reference the same location, their contributions
    multiply.

    Args:
        output (Tensor): The destination tensor (presumably modified in
            place by the backend — TODO confirm).
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)

    Returns:
        Whatever the ``scatter`` backend returns for the ``'mul'`` op
        (presumably :attr:`output` — TODO confirm).
    """
    # Delegates to the compiled backend imported from ``.scatter``; the op is
    # selected by name.
    return scatter('mul', dim, output, index, input)
rusty1s's avatar
rusty1s committed
61
62
63
64
65
66
67
68


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of :func:`scatter_mul_`.

    The destination tensor comes from ``gen_output`` and is initialised from
    :attr:`fill_value`. The default of ``1`` is presumably chosen because it
    is the multiplicative identity. The values of :attr:`input` are then
    multiplied into it.

    Args:
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)
        max_index (int, optional): Forwarded to ``gen_output``.
        fill_value (optional): Initial value of the new output.
            (default: ``1``)

    Returns:
        Whatever :func:`scatter_mul_` returns for the new tensor.
    """
    buffer = gen_output(index, input, dim, max_index, fill_value)
    product = scatter_mul_(buffer, index, input, dim=dim)
    return product


def scatter_div_(output, index, input, dim=0):
    """Divides :attr:`output` by the values of :attr:`input` at the indices
    specified in the :attr:`index` tensor along a given axis :attr:`dim`.
    Indexing follows the same rules as :func:`scatter_add_`. If multiple
    indices reference the same location, their contributions divide.

    Args:
        output (Tensor): The destination tensor (presumably modified in
            place by the backend — TODO confirm).
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)

    Returns:
        Whatever the ``scatter`` backend returns for the ``'div'`` op
        (presumably :attr:`output` — TODO confirm).
    """
    # Delegates to the compiled backend imported from ``.scatter``; the op is
    # selected by name.
    return scatter('div', dim, output, index, input)
rusty1s's avatar
rusty1s committed
72
73
74
75


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of :func:`scatter_div_`.

    Builds a fresh destination tensor with ``gen_output``, initialised from
    :attr:`fill_value`, and divides it by the values of :attr:`input` at the
    locations selected by :attr:`index` along :attr:`dim`.

    Args:
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)
        max_index (int, optional): Forwarded to ``gen_output``.
        fill_value (optional): Initial value of the new output.
            (default: ``1``)

    Returns:
        Whatever :func:`scatter_div_` returns for the new tensor.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # Return the result; without this the wrapper always yielded ``None``,
    # unlike every other out-of-place wrapper in this module.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Averages the values of :attr:`input` into :attr:`output` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. Indexing follows the same rules as :func:`scatter_add_`.
    If multiple indices reference the same location, their contributions
    average.

    NOTE(review): the existing contents of :attr:`output` are presumably
    accumulated together with the contributions by the ``'mean'`` kernel,
    so a non-zero initial value would skew the mean. Confirm against the
    backend.

    Args:
        output (Tensor): The destination tensor, modified in place.
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)

    Returns:
        Tensor: :attr:`output`, divided in place by the per-location counts.
    """
    # Per-location contribution counts, shaped like ``output`` and zero
    # initialised. Presumably the 'mean' kernel increments them — TODO
    # confirm.
    num_output = gen_filled_tensor(output, output.size(), fill_value=0)
    scatter('mean', dim, output, index, input, num_output)
    # Locations that received no contribution would divide by zero. A count
    # of 1 leaves their current value unchanged instead.
    num_output[num_output == 0] = 1
    output /= num_output
    return output


rusty1s's avatar
cleaner  
rusty1s committed
89
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_mean_`.

    Builds a fresh destination tensor with ``gen_output``, initialised from
    :attr:`fill_value`, and averages :attr:`input` into it at the locations
    selected by :attr:`index` along :attr:`dim`.

    Args:
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)
        max_index (int, optional): Forwarded to ``gen_output``.
        fill_value (optional): Initial value of the new output.
            (default: ``0``)

    Returns:
        Tensor: The newly created output tensor holding the means.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mean_(output, index, input, dim)
rusty1s's avatar
rusty1s committed
92
93


rusty1s's avatar
rusty1s committed
94
def scatter_max_(output, index, input, dim=0):
    """Scatters the values of :attr:`input` into :attr:`output` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. Indexing follows the same rules as :func:`scatter_add_`.
    If multiple indices reference the same location, the maximal
    contribution gets taken.

    NOTE(review): the existing values of :attr:`output` presumably take part
    in the comparison. In that case an output pre-filled with ``0`` can never
    hold a negative maximum — confirm against the backend.

    Args:
        output (Tensor): The destination tensor (presumably modified in
            place by the backend — TODO confirm).
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)

    Returns:
        Whatever the ``scatter`` backend returns for the ``'max'`` op
        (presumably :attr:`output` together with the argmax tensor — TODO
        confirm).
    """
    # Argmax buffer shaped like ``output`` and filled with -1. The -1
    # presumably marks locations that received no contribution — TODO
    # confirm in the kernel.
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('max', dim, output, index, input, arg_output)
rusty1s's avatar
rusty1s committed
99
100
101
102
103
104
105
106


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_max_`.

    Allocates the destination through ``gen_output``, initialised from
    :attr:`fill_value`, and keeps the maximal contribution of :attr:`input`
    per location selected by :attr:`index` along :attr:`dim`.

    Args:
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)
        max_index (int, optional): Forwarded to ``gen_output``.
        fill_value (optional): Initial value of the new output.
            (default: ``0``)

    Returns:
        Whatever :func:`scatter_max_` returns for the new tensor.
    """
    fresh = gen_output(index, input, dim, max_index, fill_value)
    return scatter_max_(fresh, index, input, dim=dim)


def scatter_min_(output, index, input, dim=0):
    """Scatters the values of :attr:`input` into :attr:`output` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. Indexing follows the same rules as :func:`scatter_add_`.
    If multiple indices reference the same location, the minimal
    contribution gets taken.

    NOTE(review): the existing values of :attr:`output` presumably take part
    in the comparison. In that case an output pre-filled with ``0`` can never
    hold a positive minimum — confirm against the backend.

    Args:
        output (Tensor): The destination tensor (presumably modified in
            place by the backend — TODO confirm).
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)

    Returns:
        Whatever the ``scatter`` backend returns for the ``'min'`` op
        (presumably :attr:`output` together with the argmin tensor — TODO
        confirm).
    """
    # Argmin buffer shaped like ``output`` and filled with -1. The -1
    # presumably marks locations that received no contribution — TODO
    # confirm in the kernel.
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('min', dim, output, index, input, arg_output)
rusty1s's avatar
rusty1s committed
111
112
113
114
115
116
117


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_min_`.

    Allocates the destination through ``gen_output``, initialised from
    :attr:`fill_value`, and keeps the minimal contribution of :attr:`input`
    per location selected by :attr:`index` along :attr:`dim`.

    Args:
        index (LongTensor): The indices of elements to scatter.
        input (Tensor): The source tensor.
        dim (int, optional): The axis along which to index. (default: ``0``)
        max_index (int, optional): Forwarded to ``gen_output``.
        fill_value (optional): Initial value of the new output.
            (default: ``0``)

    Returns:
        Whatever :func:`scatter_min_` returns for the new tensor.
    """
    return scatter_min_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim=dim,
    )


rusty1s's avatar
rusty1s committed
118
119
# Public API: an in-place variant (trailing underscore, writes into a given
# ``output``) and an out-of-place variant (allocates its own output) of each
# scatter reduction.
__all__ = [
    'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
    'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div',
    'scatter_mean_', 'scatter_mean', 'scatter_max_', 'scatter_max',
    'scatter_min_', 'scatter_min',
]