__init__.py 2.98 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from .scatter import scatter
rusty1s's avatar
cleaner  
rusty1s committed
2
from .utils import gen_filled_tensor, gen_output
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
from .add import scatter_add_, scatter_add
rusty1s's avatar
rusty1s committed
5
6
7


def scatter_sub_(output, index, input, dim=0):
rusty1s's avatar
doc  
rusty1s committed
8
9
    """If multiple indices reference the same location, their negated
    contributions add."""
rusty1s's avatar
rusty1s committed
10
    return output.scatter_add_(dim, index, -input)
rusty1s's avatar
rusty1s committed
11
12
13
14
15
16
17
18


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_sub_(output, index, input, dim)


def scatter_mul_(output, index, input, dim=0):
rusty1s's avatar
doc  
rusty1s committed
19
20
    """If multiple indices reference the same location, their
    contributions multiply."""
rusty1s's avatar
rusty1s committed
21
    return scatter('mul', dim, output, index, input)
rusty1s's avatar
rusty1s committed
22
23
24
25
26
27
28
29


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mul_(output, index, input, dim)


def scatter_div_(output, index, input, dim=0):
rusty1s's avatar
doc  
rusty1s committed
30
31
    """If multiple indices reference the same location, their
    contributions divide."""
rusty1s's avatar
rusty1s committed
32
    return scatter('div', dim, output, index, input)
rusty1s's avatar
rusty1s committed
33
34
35
36


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    output = gen_output(index, input, dim, max_index, fill_value)
rusty1s's avatar
rusty1s committed
37
38
39
40
    scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
rusty1s's avatar
doc  
rusty1s committed
41
42
    """If multiple indices reference the same location, their
    contributions average."""
rusty1s's avatar
rename  
rusty1s committed
43
44
45
46
    num_output = gen_filled_tensor(output, output.size(), fill_value=0)
    scatter('mean', dim, output, index, input, num_output)
    num_output[num_output == 0] = 1
    output /= num_output
rusty1s's avatar
rusty1s committed
47
48
49
    return output


rusty1s's avatar
cleaner  
rusty1s committed
50
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
rusty1s's avatar
rusty1s committed
51
52
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mean_(output, index, input, dim)
rusty1s's avatar
rusty1s committed
53
54


rusty1s's avatar
rusty1s committed
55
def scatter_max_(output, index, input, dim=0):
rusty1s's avatar
doc  
rusty1s committed
56
    """If multiple indices reference the same location, the maximal
rusty1s's avatar
rusty1s committed
57
58
59
60
    contribution gets taken.

    :rtype: (:class:`Tensor`, :class:`LongTensor`)
    """
rusty1s's avatar
rename  
rusty1s committed
61
62
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('max', dim, output, index, input, arg_output)
rusty1s's avatar
rusty1s committed
63
64
65
66
67
68
69
70


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_max_(output, index, input, dim)


def scatter_min_(output, index, input, dim=0):
rusty1s's avatar
doc  
rusty1s committed
71
72
    """If multiple indices reference the same location, the minimal
    contribution gets taken."""
rusty1s's avatar
rename  
rusty1s committed
73
74
    arg_output = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('min', dim, output, index, input, arg_output)
rusty1s's avatar
rusty1s committed
75
76
77
78
79
80
81


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_min_(output, index, input, dim)


rusty1s's avatar
rusty1s committed
82
83
__all__ = [
    'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
rusty1s's avatar
rusty1s committed
84
    'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div',
rusty1s's avatar
rusty1s committed
85
86
    'scatter_mean_', 'scatter_mean', 'scatter_max_', 'scatter_max',
    'scatter_min_', 'scatter_min'
rusty1s's avatar
rusty1s committed
87
]