__init__.py 4.53 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
from .scatter import scatter
rusty1s's avatar
cleaner  
rusty1s committed
4
from .utils import gen_filled_tensor, gen_output
rusty1s's avatar
rusty1s committed
5
6
7


def scatter_add_(output, index, input, dim=0):
    """Sums up all values from the tensor :attr:`input` into :attr:`output` at
    the indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. For each value in :attr:`input`, its output index is specified
    by its index in :attr:`input` for dimension != :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension = :attr:`dim`. If
    multiple indices reference the same location, their contributions add.

    If :attr:`input` and :attr:`index` are n-dimensional tensors with equal
    size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
    :attr:`dim` = i, then :attr:`output` must be an n-dimensional tensor with
    size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
    values of :attr:`index` must be between `0` and `output.size(dim) - 1`.

    Args:
        output (Tensor): The destination tensor (modified in place)
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index

    Returns:
        The result of ``output.scatter_add_(dim, index, input)``, i.e. the
        updated :attr:`output` tensor.

    Example::

        >> input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        >> index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        >> output = torch.zeros(2, 6)
        >> scatter_add_(output, index, input, dim=1)
        0  0  4  3  3  0
        2  4  4  0  0  0
        [torch.FloatTensor of size 2x6]
    """
    # Addition is natively supported by torch, so no custom kernel is needed.
    return output.scatter_add_(dim, index, input)
rusty1s's avatar
rusty1s committed
37
38
39
40
41
42
43
44


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_add_`.

    The destination tensor is obtained from ``gen_output`` (presumably sized
    from :attr:`index` / :attr:`max_index` along :attr:`dim` and pre-filled
    with :attr:`fill_value` -- confirm against ``utils.gen_output``) and is
    then handed to :func:`scatter_add_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index
        max_index (int, optional): Forwarded unchanged to ``gen_output``
        fill_value (int, optional): Forwarded unchanged to ``gen_output``

    Returns:
        Whatever :func:`scatter_add_` returns for the generated tensor.
    """
    return scatter_add_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )


def scatter_sub_(output, index, input, dim=0):
    """Subtracts all values of :attr:`input` from :attr:`output` at the
    indices specified in :attr:`index` along the axis :attr:`dim`. If multiple
    indices reference the same location, their negated contributions add.

    Args:
        output (Tensor): The destination tensor (modified in place)
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index

    Returns:
        The result of ``output.scatter_add_(dim, index, -input)``, i.e. the
        updated :attr:`output` tensor.
    """
    # Subtraction is expressed as scatter-adding the negated source.
    return output.scatter_add_(dim, index, -input)
rusty1s's avatar
rusty1s committed
48
49
50
51
52
53
54
55


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_sub_`.

    A destination tensor is requested from ``gen_output`` (presumably
    pre-filled with :attr:`fill_value` -- confirm against
    ``utils.gen_output``) and then passed to :func:`scatter_sub_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index
        max_index (int, optional): Forwarded unchanged to ``gen_output``
        fill_value (int, optional): Forwarded unchanged to ``gen_output``

    Returns:
        Whatever :func:`scatter_sub_` returns for the generated tensor.
    """
    target = gen_output(index, input, dim, max_index, fill_value)
    result = scatter_sub_(target, index, input, dim)
    return result


def scatter_mul_(output, index, input, dim=0):
    """Multiplies all values of :attr:`input` into :attr:`output` at the
    indices specified in :attr:`index` along the axis :attr:`dim`. If multiple
    indices reference the same location, their contributions multiply.

    Args:
        output (Tensor): The destination tensor
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index

    Returns:
        The value returned by ``scatter('mul', ...)`` (presumably the updated
        :attr:`output` -- confirm against ``.scatter.scatter``).
    """
    return scatter('mul', dim, output, index, input)
rusty1s's avatar
rusty1s committed
59
60
61
62
63
64
65
66


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of :func:`scatter_mul_`.

    The destination tensor comes from ``gen_output``; the default
    :attr:`fill_value` of ``1`` is the multiplicative identity (presumably the
    initial content of the destination -- confirm against
    ``utils.gen_output``).

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index
        max_index (int, optional): Forwarded unchanged to ``gen_output``
        fill_value (int, optional): Forwarded unchanged to ``gen_output``

    Returns:
        Whatever :func:`scatter_mul_` returns for the generated tensor.
    """
    destination = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mul_(output=destination, index=index, input=input, dim=dim)


def scatter_div_(output, index, input, dim=0):
    """Divides :attr:`output` by all values of :attr:`input` at the indices
    specified in :attr:`index` along the axis :attr:`dim`. If multiple indices
    reference the same location, their contributions divide.

    Args:
        output (Tensor): The destination tensor
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index

    Returns:
        The value returned by ``scatter('div', ...)`` (presumably the updated
        :attr:`output` -- confirm against ``.scatter.scatter``).
    """
    return scatter('div', dim, output, index, input)
rusty1s's avatar
rusty1s committed
70
71
72
73


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of :func:`scatter_div_`.

    The destination tensor comes from ``gen_output``; the default
    :attr:`fill_value` of ``1`` is the neutral starting value for division
    (presumably the initial content of the destination -- confirm against
    ``utils.gen_output``).

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index
        max_index (int, optional): Forwarded unchanged to ``gen_output``
        fill_value (int, optional): Forwarded unchanged to ``gen_output``

    Returns:
        Whatever :func:`scatter_div_` returns for the generated tensor.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # The result must be returned; every other out-of-place wrapper in this
    # module does so, and without it callers would always receive None.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Averages all values of :attr:`input` into :attr:`output` at the indices
    specified in :attr:`index` along the axis :attr:`dim`. If multiple indices
    reference the same location, their contributions average.

    Args:
        output (Tensor): The destination tensor (modified in place)
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index

    Returns:
        The :attr:`output` tensor after the in-place division by the
        per-location counts.
    """
    # Companion tensor with the size of `output`, starting at zero; presumably
    # `scatter('mean', ...)` accumulates sums into `output` and the number of
    # contributions per location into `num_output` -- confirm against
    # `.scatter.scatter`.
    num_output = gen_filled_tensor(output, output.size(), fill_value=0)
    scatter('mean', dim, output, index, input, num_output)
    # Locations with a zero count are divided by 1 instead, which avoids a
    # division by zero and leaves their value unchanged.
    num_output[num_output == 0] = 1
    output /= num_output
    return output


rusty1s's avatar
cleaner  
rusty1s committed
87
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_mean_`.

    The destination tensor is obtained from ``gen_output`` (presumably
    pre-filled with :attr:`fill_value` -- confirm against
    ``utils.gen_output``) and then passed to :func:`scatter_mean_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index
        max_index (int, optional): Forwarded unchanged to ``gen_output``
        fill_value (int, optional): Forwarded unchanged to ``gen_output``

    Returns:
        The tensor returned by :func:`scatter_mean_`.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mean_(output, index, input, dim)
rusty1s's avatar
rusty1s committed
90
91


rusty1s's avatar
rusty1s committed
92
def scatter_max_(output, index, input, dim=0):
    """Maximizes all values of :attr:`input` into :attr:`output` at the
    indices specified in :attr:`index` along the axis :attr:`dim`. If multiple
    indices reference the same location, the maximal contribution gets taken.

    Args:
        output (Tensor): The destination tensor
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index

    Returns:
        The value returned by ``scatter('max', ...)`` (presumably involves
        both :attr:`output` and the argmax tensor -- confirm against
        ``.scatter.scatter``).
    """
    # Auxiliary tensor with the size of `output`, initialised to -1;
    # presumably receives the position of the winning element, with -1
    # marking locations that got no contribution -- TODO confirm.
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('max', dim, output, index, input, arg_output)
rusty1s's avatar
rusty1s committed
97
98
99
100
101
102
103
104


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_max_`.

    The destination tensor is obtained from ``gen_output`` (presumably
    pre-filled with :attr:`fill_value` -- confirm against
    ``utils.gen_output``) and then passed to :func:`scatter_max_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index
        max_index (int, optional): Forwarded unchanged to ``gen_output``
        fill_value (int, optional): Forwarded unchanged to ``gen_output``

    Returns:
        Whatever :func:`scatter_max_` returns for the generated tensor.
    """
    return scatter_max_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )


def scatter_min_(output, index, input, dim=0):
    """Minimizes all values of :attr:`input` into :attr:`output` at the
    indices specified in :attr:`index` along the axis :attr:`dim`. If multiple
    indices reference the same location, the minimal contribution gets taken.

    Args:
        output (Tensor): The destination tensor
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index

    Returns:
        The value returned by ``scatter('min', ...)`` (presumably involves
        both :attr:`output` and the argmin tensor -- confirm against
        ``.scatter.scatter``).
    """
    # Auxiliary tensor with the size of `output`, initialised to -1;
    # presumably receives the position of the winning element, with -1
    # marking locations that got no contribution -- TODO confirm.
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('min', dim, output, index, input, arg_output)
rusty1s's avatar
rusty1s committed
109
110
111
112
113
114
115


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_min_`.

    The destination tensor is obtained from ``gen_output`` (presumably
    pre-filled with :attr:`fill_value` -- confirm against
    ``utils.gen_output``) and then passed to :func:`scatter_min_`.

    Args:
        index (LongTensor): The indices of elements to scatter
        input (Tensor): The source tensor
        dim (int): The axis along which to index
        max_index (int, optional): Forwarded unchanged to ``gen_output``
        fill_value (int, optional): Forwarded unchanged to ``gen_output``

    Returns:
        Whatever :func:`scatter_min_` returns for the generated tensor.
    """
    target = gen_output(index, input, dim, max_index, fill_value)
    outcome = scatter_min_(target, index, input, dim)
    return outcome


rusty1s's avatar
rusty1s committed
116
117
# Public API of the package: every operation is exported both as an in-place
# variant (trailing underscore) and as an out-of-place variant.
__all__ = [
    'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
    'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div',
    'scatter_mean_', 'scatter_mean', 'scatter_max_', 'scatter_max',
    'scatter_min_', 'scatter_min'
]