__init__.py 1.82 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
from .scatter import scatter
from .utils import gen_output


def scatter_add_(output, index, input, dim=0):
    """Run the ``'add'`` scatter operation on ``output`` and return it.

    Forwards to ``scatter('add', dim, output, index, input)``.
    Presumably this accumulates the values of ``input`` into ``output`` at
    the positions named by ``index`` along ``dim``, modifying ``output``
    in place -- confirm against the ``scatter`` helper.

    Args:
        output: Target object handed to ``scatter``; returned afterwards.
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``.
        dim: Dimension forwarded to ``scatter`` (default ``0``).

    Returns:
        The same ``output`` object that was passed in.
    """
    scatter('add', dim, output, index, input)
    return output
rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
14
15


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :func:`scatter_add_`.

    Obtains a fresh output from
    ``gen_output(index, input, dim, max_index, fill_value)`` and passes it
    to :func:`scatter_add_`, returning whatever that call returns.

    Args:
        index: Index argument forwarded to both helpers.
        input: Source values forwarded to both helpers.
        dim: Dimension forwarded to both helpers (default ``0``).
        max_index: Forwarded to ``gen_output``; presumably bounds the size
            of the generated output along ``dim`` -- confirm in ``utils``.
        fill_value: Forwarded to ``gen_output``; presumably the initial
            value of the generated output (default ``0``).
    """
    return scatter_add_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )


def scatter_sub_(output, index, input, dim=0):
    """Run the ``'sub'`` scatter operation on ``output`` and return it.

    Forwards to ``scatter('sub', dim, output, index, input)``.
    Presumably this subtracts the values of ``input`` from ``output`` at
    the positions named by ``index`` along ``dim``, modifying ``output``
    in place -- confirm against the ``scatter`` helper.

    Args:
        output: Target object handed to ``scatter``; returned afterwards.
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``.
        dim: Dimension forwarded to ``scatter`` (default ``0``).

    Returns:
        The same ``output`` object that was passed in.
    """
    scatter('sub', dim, output, index, input)
    return output
rusty1s's avatar
rusty1s committed
18
19
20
21
22
23
24
25


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of :func:`scatter_sub_`.

    Obtains a fresh output from
    ``gen_output(index, input, dim, max_index, fill_value)`` and passes it
    to :func:`scatter_sub_`, returning whatever that call returns.

    Args:
        index: Index argument forwarded to both helpers.
        input: Source values forwarded to both helpers.
        dim: Dimension forwarded to both helpers (default ``0``).
        max_index: Forwarded to ``gen_output``; presumably bounds the size
            of the generated output along ``dim`` -- confirm in ``utils``.
        fill_value: Forwarded to ``gen_output``; presumably the initial
            value of the generated output (default ``0``).
    """
    return scatter_sub_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )


def scatter_mul_(output, index, input, dim=0):
    """Run the ``'mul'`` scatter operation on ``output`` and return it.

    Forwards to ``scatter('mul', dim, output, index, input)``.
    Presumably this multiplies ``output`` by the values of ``input`` at
    the positions named by ``index`` along ``dim``, modifying ``output``
    in place -- confirm against the ``scatter`` helper.

    Args:
        output: Target object handed to ``scatter``; returned afterwards.
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``.
        dim: Dimension forwarded to ``scatter`` (default ``0``).

    Returns:
        The same ``output`` object that was passed in.
    """
    scatter('mul', dim, output, index, input)
    return output
rusty1s's avatar
rusty1s committed
28
29
30
31
32
33
34
35


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of :func:`scatter_mul_`.

    Obtains a fresh output from
    ``gen_output(index, input, dim, max_index, fill_value)`` and passes it
    to :func:`scatter_mul_`, returning whatever that call returns.

    Args:
        index: Index argument forwarded to both helpers.
        input: Source values forwarded to both helpers.
        dim: Dimension forwarded to both helpers (default ``0``).
        max_index: Forwarded to ``gen_output``; presumably bounds the size
            of the generated output along ``dim`` -- confirm in ``utils``.
        fill_value: Forwarded to ``gen_output``; presumably the initial
            value of the generated output (default ``1``).
    """
    return scatter_mul_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )


def scatter_div_(output, index, input, dim=0):
    """Run the ``'div'`` scatter operation on ``output`` and return it.

    Forwards to ``scatter('div', dim, output, index, input)``.
    Presumably this divides ``output`` by the values of ``input`` at the
    positions named by ``index`` along ``dim``, modifying ``output`` in
    place -- confirm against the ``scatter`` helper.

    Args:
        output: Target object handed to ``scatter``; returned afterwards.
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``.
        dim: Dimension forwarded to ``scatter`` (default ``0``).

    Returns:
        The same ``output`` object that was passed in.
    """
    scatter('div', dim, output, index, input)
    return output
rusty1s's avatar
rusty1s committed
38
39
40
41


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of :func:`scatter_div_`.

    Obtains a fresh output from
    ``gen_output(index, input, dim, max_index, fill_value)`` and passes it
    to :func:`scatter_div_`.

    Args:
        index: Index argument forwarded to both helpers.
        input: Source values forwarded to both helpers.
        dim: Dimension forwarded to both helpers (default ``0``).
        max_index: Forwarded to ``gen_output``; presumably bounds the size
            of the generated output along ``dim`` -- confirm in ``utils``.
        fill_value: Forwarded to ``gen_output``; presumably the initial
            value of the generated output (default ``1``).

    Returns:
        The output returned by :func:`scatter_div_`.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # The result must be returned explicitly: without the ``return`` this
    # function yielded ``None``, unlike the add/sub/mul/mean counterparts.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Run the ``'mean'`` scatter operation on ``output`` and return it.

    A zero-filled companion buffer with the size of ``output`` is handed to
    ``scatter('mean', ...)`` as an extra argument; afterwards ``output`` is
    divided by that buffer in place and every NaN entry is reset to ``0``.
    Presumably ``scatter`` accumulates sums into ``output`` and per-slot
    element counts into the buffer -- confirm against the ``scatter`` helper.

    Args:
        output: Target object handed to ``scatter``; assumed to be a
            tensor-like supporting ``new``, ``size``, in-place division and
            boolean-mask assignment.
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``.
        dim: Dimension forwarded to ``scatter`` (default ``0``).

    Returns:
        The same ``output`` object that was passed in.
    """
    counts = output.new(output.size()).fill_(0)
    scatter('mean', dim, output, index, input, counts)
    output /= counts
    # NaN is the only value that compares unequal to itself, so this mask
    # selects exactly the NaN entries (e.g. the result of 0 / 0).
    nan_mask = output != output
    output[nan_mask] = 0
    return output


def scatter_mean(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of :func:`scatter_mean_`.

    Obtains a fresh output from
    ``gen_output(index, input, dim, max_index, fill_value)`` and passes it
    to :func:`scatter_mean_`, returning whatever that call returns.

    NOTE(review): the default ``fill_value=1`` matches ``scatter_mul`` and
    ``scatter_div``; for a mean, ``0`` looks like the more natural starting
    value -- confirm whether this default is intended.

    Args:
        index: Index argument forwarded to both helpers.
        input: Source values forwarded to both helpers.
        dim: Dimension forwarded to both helpers (default ``0``).
        max_index: Forwarded to ``gen_output``; presumably bounds the size
            of the generated output along ``dim`` -- confirm in ``utils``.
        fill_value: Forwarded to ``gen_output``; presumably the initial
            value of the generated output (default ``1``).
    """
    return scatter_mean_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )
rusty1s's avatar
rusty1s committed
56
57
58
59


# Public API of the package: each operation is exported both in its in-place
# form (trailing underscore) and its out-of-place form.
__all__ = [
    'scatter_add_',
    'scatter_add',
    'scatter_sub_',
    'scatter_sub',
    'scatter_mul_',
    'scatter_mul',
    'scatter_div_',
    'scatter_div',
    'scatter_mean_',
    'scatter_mean',
]