__init__.py 2.93 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from .scatter import scatter
rusty1s's avatar
cleaner  
rusty1s committed
2
from .utils import gen_filled_tensor, gen_output
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
from .add import scatter_add_, scatter_add
rusty1s's avatar
rusty1s committed
5
6
7


def scatter_sub_(output, index, input, dim=0):
    """Subtract the values of ``input`` from ``output`` in place.

    Each ``input`` entry is subtracted from the ``output`` location selected
    by the matching ``index`` entry along ``dim``. If multiple indices
    reference the same location, their negated contributions add up.

    Args:
        output: Tensor that is updated in place.
        index: Tensor of target positions along ``dim``.
        input: Tensor of values to subtract.
        dim: Dimension along which ``index`` selects positions.

    Returns:
        The value returned by ``output.scatter_add_`` (``output`` itself).
    """
    # Subtraction is expressed as a scatter-add of the negated values.
    negated = -input
    return output.scatter_add_(dim, index, negated)


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :func:`scatter_sub_`.

    A new output is created with ``gen_output`` and the negated ``input``
    values are then scattered into it. ``gen_output`` presumably sizes the
    output from ``index``/``max_index`` and fills it with ``fill_value``;
    confirm this in ``utils.gen_output``.

    Returns:
        Whatever :func:`scatter_sub_` returns for the freshly created output.
    """
    fresh = gen_output(index, input, dim, max_index, fill_value)
    return scatter_sub_(fresh, index, input, dim)


def scatter_mul_(output, index, input, dim=0):
    """Multiply the values of ``input`` into ``output``.

    If multiple indices reference the same location, their contributions
    multiply. The work is delegated to the ``'mul'`` reduction of
    ``scatter``. The trailing underscore suggests ``output`` is updated in
    place; confirm this in ``scatter``.

    Args:
        output: Tensor that receives the products.
        index: Tensor of target positions along ``dim``.
        input: Tensor of multiplicands.
        dim: Dimension along which ``index`` selects positions.

    Returns:
        The result of the ``scatter`` call.
    """
    reduction = 'mul'
    return scatter(reduction, dim, output, index, input)


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of :func:`scatter_mul_`.

    ``fill_value`` defaults to 1, the multiplicative identity. A new output
    is built with ``gen_output``; it is presumably filled with
    ``fill_value`` (see ``utils.gen_output``).

    Returns:
        Whatever :func:`scatter_mul_` returns for the freshly created output.
    """
    fresh = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mul_(fresh, index, input, dim)


def scatter_div_(output, index, input, dim=0):
    """Divide ``output`` by the values of ``input`` scattered via ``index``.

    If multiple indices reference the same location, their contributions
    divide. The work is delegated to the ``'div'`` reduction of ``scatter``.
    The trailing underscore suggests ``output`` is updated in place; confirm
    this in ``scatter``.

    Args:
        output: Tensor that receives the quotients.
        index: Tensor of target positions along ``dim``.
        input: Tensor of divisors.
        dim: Dimension along which ``index`` selects positions.

    Returns:
        The result of the ``scatter`` call.
    """
    reduction = 'div'
    return scatter(reduction, dim, output, index, input)


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of :func:`scatter_div_`.

    ``fill_value`` defaults to 1, the multiplicative identity. A new output
    is built with ``gen_output``; it is presumably filled with
    ``fill_value`` (see ``utils.gen_output``).

    Returns:
        Whatever :func:`scatter_div_` returns for the freshly created output.
        This result was previously dropped, so the function always returned
        ``None``. It is now returned, matching the other out-of-place
        wrappers in this module.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Average the values of ``input`` into ``output``.

    If multiple indices reference the same location, their contributions
    average.

    Args:
        output: Tensor that receives the averages; it is divided in place.
        index: Tensor of target positions along ``dim``.
        input: Tensor of values to average.
        dim: Dimension along which ``index`` selects positions.

    Returns:
        ``output`` after dividing it by the per-location counts.
    """
    # Counter tensor with output's shape, starting at zero. gen_filled_tensor
    # presumably uses ``output`` only as a type template; confirm in utils.
    counts = gen_filled_tensor(output, output.size(), fill_value=0)
    # The 'mean' kernel presumably accumulates sums into ``output`` and
    # hit counts into ``counts``, as the division below implies.
    scatter('mean', dim, output, index, input, counts)
    # Locations that no index hit keep a count of 0. Setting them to 1 avoids
    # a division by zero and leaves those locations unchanged.
    untouched = counts == 0
    counts[untouched] = 1
    output /= counts
    return output


def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :func:`scatter_mean_`.

    A new output is built with ``gen_output``; it is presumably filled with
    ``fill_value`` (see ``utils.gen_output``).

    Returns:
        Whatever :func:`scatter_mean_` returns for the freshly created output.
    """
    fresh = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mean_(fresh, index, input, dim)


def scatter_max_(output, index, input, dim=0):
    """Scatter the maximum of ``input`` into ``output``.

    If multiple indices reference the same location, the maximal
    contribution gets taken. The work is delegated to the ``'max'``
    reduction of ``scatter``.

    Args:
        output: Tensor that receives the maxima.
        index: Tensor of target positions along ``dim``.
        input: Tensor of candidate values.
        dim: Dimension along which ``index`` selects positions.

    Returns:
        The result of the ``scatter`` call.
    """
    # Companion tensor built from ``index`` with output's shape and prefilled
    # with -1. It presumably records the winning argument positions, with -1
    # marking locations that got no contribution; confirm in ``scatter``.
    argmax = gen_filled_tensor(index, output.size(), fill_value=-1)
    reduction = 'max'
    return scatter(reduction, dim, output, index, input, argmax)


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :func:`scatter_max_`.

    A new output is built with ``gen_output``; it is presumably filled with
    ``fill_value`` (see ``utils.gen_output``).

    Returns:
        Whatever :func:`scatter_max_` returns for the freshly created output.
    """
    fresh = gen_output(index, input, dim, max_index, fill_value)
    return scatter_max_(fresh, index, input, dim)


def scatter_min_(output, index, input, dim=0):
    """Scatter the minimum of ``input`` into ``output``.

    If multiple indices reference the same location, the minimal
    contribution gets taken. The work is delegated to the ``'min'``
    reduction of ``scatter``.

    Args:
        output: Tensor that receives the minima.
        index: Tensor of target positions along ``dim``.
        input: Tensor of candidate values.
        dim: Dimension along which ``index`` selects positions.

    Returns:
        The result of the ``scatter`` call.
    """
    # Companion tensor built from ``index`` with output's shape and prefilled
    # with -1. It presumably records the winning argument positions, with -1
    # marking locations that got no contribution; confirm in ``scatter``.
    argmin = gen_filled_tensor(index, output.size(), fill_value=-1)
    reduction = 'min'
    return scatter(reduction, dim, output, index, input, argmin)


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :func:`scatter_min_`.

    A new output is built with ``gen_output``; it is presumably filled with
    ``fill_value`` (see ``utils.gen_output``).

    Returns:
        Whatever :func:`scatter_min_` returns for the freshly created output.
    """
    fresh = gen_output(index, input, dim, max_index, fill_value)
    return scatter_min_(fresh, index, input, dim)


# Public API. Each reduction has a trailing-underscore variant that takes a
# caller-supplied ``output`` and a variant that builds its own output.
__all__ = [
    'scatter_add_',
    'scatter_add',
    'scatter_sub_',
    'scatter_sub',
    'scatter_mul_',
    'scatter_mul',
    'scatter_div_',
    'scatter_div',
    'scatter_mean_',
    'scatter_mean',
    'scatter_max_',
    'scatter_max',
    'scatter_min_',
    'scatter_min',
]