__init__.py 2.98 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from .scatter import scatter
rusty1s's avatar
cleaner  
rusty1s committed
2
from .utils import gen_filled_tensor, gen_output
rusty1s's avatar
rusty1s committed
3
4
5


def scatter_add_(output, index, input, dim=0):
    """Apply the 'add' scatter rule to ``output``.

    Whenever several indices refer to the same location, the matching
    contributions from ``input`` are summed. The call is forwarded to
    ``scatter`` and whatever it returns is handed back unchanged.
    """
    result = scatter('add', dim, output, index, input)
    return result
rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    """Run :func:`scatter_add_` on an output obtained from ``gen_output``.

    ``gen_output`` receives ``index``, ``input``, ``dim``, ``max_index`` and
    ``fill_value`` (default ``0``); presumably it allocates a tensor filled
    with ``fill_value`` -- confirm against ``.utils``.
    """
    return scatter_add_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )


def scatter_sub_(output, index, input, dim=0):
    """Apply the 'sub' scatter rule to ``output``.

    Whenever several indices refer to the same location, the negated
    contributions from ``input`` are summed. The call is forwarded to
    ``scatter`` and whatever it returns is handed back unchanged.
    """
    result = scatter('sub', dim, output, index, input)
    return result
rusty1s's avatar
rusty1s committed
20
21
22
23
24
25
26
27


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Run :func:`scatter_sub_` on an output obtained from ``gen_output``.

    ``gen_output`` receives ``index``, ``input``, ``dim``, ``max_index`` and
    ``fill_value`` (default ``0``); presumably it allocates a tensor filled
    with ``fill_value`` -- confirm against ``.utils``.
    """
    target = gen_output(index, input, dim, max_index, fill_value)
    result = scatter_sub_(target, index, input, dim)
    return result


def scatter_mul_(output, index, input, dim=0):
    """Apply the 'mul' scatter rule to ``output``.

    Whenever several indices refer to the same location, the matching
    contributions from ``input`` are multiplied together. The call is
    forwarded to ``scatter`` and whatever it returns is handed back
    unchanged.
    """
    result = scatter('mul', dim, output, index, input)
    return result
rusty1s's avatar
rusty1s committed
31
32
33
34
35
36
37
38


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Run :func:`scatter_mul_` on an output obtained from ``gen_output``.

    ``fill_value`` defaults to ``1`` here and is forwarded to ``gen_output``
    together with ``index``, ``input``, ``dim`` and ``max_index``.
    """
    return scatter_mul_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )


def scatter_div_(output, index, input, dim=0):
    """Apply the 'div' scatter rule to ``output``.

    Whenever several indices refer to the same location, the matching
    contributions from ``input`` divide. The call is forwarded to
    ``scatter`` and whatever it returns is handed back unchanged.
    """
    result = scatter('div', dim, output, index, input)
    return result
rusty1s's avatar
rusty1s committed
42
43
44
45


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Run :func:`scatter_div_` on an output obtained from ``gen_output``.

    ``fill_value`` defaults to ``1`` here and is forwarded to ``gen_output``
    together with ``index``, ``input``, ``dim`` and ``max_index``.

    Returns:
        Whatever :func:`scatter_div_` returns for the generated output.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # The result must be returned explicitly; without the `return` the
    # caller always received None, unlike scatter_add/sub/mul.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Apply the 'mean' scatter rule to ``output`` and return ``output``.

    A companion count tensor (built by ``gen_filled_tensor`` from ``output``
    with ``output``'s size, initialised to ``0``) is passed to ``scatter``
    alongside ``output``. Afterwards ``output`` is divided in place by the
    counts.
    """
    shape = output.size()
    counts = gen_filled_tensor(output, shape, fill_value=0)
    scatter('mean', dim, output, index, input, counts)

    # Zero counts are replaced by one so the division below never divides
    # by zero.
    untouched = counts == 0
    counts[untouched] = 1

    output /= counts
    return output


rusty1s's avatar
cleaner  
rusty1s committed
57
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    """Run :func:`scatter_mean_` on an output obtained from ``gen_output``.

    ``gen_output`` receives ``index``, ``input``, ``dim``, ``max_index`` and
    ``fill_value`` (default ``0``).
    """
    target = gen_output(index, input, dim, max_index, fill_value)
    averaged = scatter_mean_(target, index, input, dim)
    return averaged
rusty1s's avatar
rusty1s committed
60
61


rusty1s's avatar
rusty1s committed
62
def scatter_max_(output, index, input, dim=0):
    """Apply the 'max' scatter rule to ``output``.

    An auxiliary tensor is built by ``gen_filled_tensor`` from ``index``
    with ``output``'s size, initialised to ``-1``, and passed to
    ``scatter`` as the extra argument. Presumably it records the position
    of each maximum -- confirm against ``.scatter``. The result of
    ``scatter`` is returned unchanged.
    """
    shape = output.size()
    arg = gen_filled_tensor(index, shape, fill_value=-1)
    result = scatter('max', dim, output, index, input, arg)
    return result
rusty1s's avatar
rusty1s committed
65
66
67
68
69
70
71
72


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    """Run :func:`scatter_max_` on an output obtained from ``gen_output``.

    ``gen_output`` receives ``index``, ``input``, ``dim``, ``max_index`` and
    ``fill_value`` (default ``0``).
    """
    return scatter_max_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )


def scatter_min_(output, index, input, dim=0):
    """Apply the 'min' scatter rule to ``output``.

    An auxiliary tensor is built by ``gen_filled_tensor`` from ``index``
    with ``output``'s size, initialised to ``-1``, and passed to
    ``scatter`` as the extra argument. Presumably it records the position
    of each minimum -- confirm against ``.scatter``. The result of
    ``scatter`` is returned unchanged.
    """
    shape = output.size()
    arg = gen_filled_tensor(index, shape, fill_value=-1)
    result = scatter('min', dim, output, index, input, arg)
    return result
rusty1s's avatar
rusty1s committed
75
76
77
78
79
80
81


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    """Run :func:`scatter_min_` on an output obtained from ``gen_output``.

    ``gen_output`` receives ``index``, ``input``, ``dim``, ``max_index`` and
    ``fill_value`` (default ``0``).
    """
    target = gen_output(index, input, dim, max_index, fill_value)
    result = scatter_min_(target, index, input, dim)
    return result


rusty1s's avatar
rusty1s committed
82
83
# Public API: each operation is listed as its in-place form (trailing
# underscore) followed by the variant that generates its own output.
__all__ = [
    'scatter_add_',
    'scatter_add',
    'scatter_sub_',
    'scatter_sub',
    'scatter_mul_',
    'scatter_mul',
    'scatter_div_',
    'scatter_div',
    'scatter_mean_',
    'scatter_mean',
    'scatter_max_',
    'scatter_max',
    'scatter_min_',
    'scatter_min',
]