__init__.py 2.63 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from .scatter import scatter
rusty1s's avatar
cleaner  
rusty1s committed
2
from .utils import gen_filled_tensor, gen_output
rusty1s's avatar
rusty1s committed
3
4
5


def scatter_add_(output, index, input, dim=0):
    """Apply the ``'add'`` scatter operation to ``output``.

    Thin wrapper that forwards to :func:`scatter` with the operation
    name ``'add'``.

    Args:
        output: Target forwarded to :func:`scatter`. Presumably updated
            in place, as the trailing underscore suggests -- confirm
            against :func:`scatter`.
        index: Index argument forwarded to :func:`scatter`.
        input: Source values forwarded to :func:`scatter`.
        dim: Dimension forwarded to :func:`scatter`. Defaults to ``0``.

    Returns:
        Whatever :func:`scatter` returns for this call.
    """
    return scatter('add', dim, output, index, input)
rusty1s's avatar
rusty1s committed
7
8
9
10
11
12
13
14


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_add_`.

    A fresh output is obtained from ``gen_output`` and handed to
    :func:`scatter_add_`.

    Args:
        index: Index argument, forwarded to both helpers.
        input: Source values, forwarded to both helpers.
        dim: Dimension, forwarded to both helpers. Defaults to ``0``.
        max_index: Forwarded to ``gen_output``. Defaults to ``None``.
        fill_value: Forwarded to ``gen_output``. Defaults to ``0``.

    Returns:
        The result of :func:`scatter_add_` on the generated output.
    """
    return scatter_add_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_sub_(output, index, input, dim=0):
    """Apply the ``'sub'`` scatter operation to ``output``.

    Thin wrapper that forwards to :func:`scatter` with the operation
    name ``'sub'``.

    Args:
        output: Target forwarded to :func:`scatter`. Presumably updated
            in place, as the trailing underscore suggests -- confirm
            against :func:`scatter`.
        index: Index argument forwarded to :func:`scatter`.
        input: Source values forwarded to :func:`scatter`.
        dim: Dimension forwarded to :func:`scatter`. Defaults to ``0``.

    Returns:
        Whatever :func:`scatter` returns for this call.
    """
    return scatter('sub', dim, output, index, input)
rusty1s's avatar
rusty1s committed
16
17
18
19
20
21
22
23


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_sub_`.

    A fresh output is obtained from ``gen_output`` and handed to
    :func:`scatter_sub_`.

    Args:
        index: Index argument, forwarded to both helpers.
        input: Source values, forwarded to both helpers.
        dim: Dimension, forwarded to both helpers. Defaults to ``0``.
        max_index: Forwarded to ``gen_output``. Defaults to ``None``.
        fill_value: Forwarded to ``gen_output``. Defaults to ``0``.

    Returns:
        The result of :func:`scatter_sub_` on the generated output.
    """
    return scatter_sub_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_mul_(output, index, input, dim=0):
    """Apply the ``'mul'`` scatter operation to ``output``.

    Thin wrapper that forwards to :func:`scatter` with the operation
    name ``'mul'``.

    Args:
        output: Target forwarded to :func:`scatter`. Presumably updated
            in place, as the trailing underscore suggests -- confirm
            against :func:`scatter`.
        index: Index argument forwarded to :func:`scatter`.
        input: Source values forwarded to :func:`scatter`.
        dim: Dimension forwarded to :func:`scatter`. Defaults to ``0``.

    Returns:
        Whatever :func:`scatter` returns for this call.
    """
    return scatter('mul', dim, output, index, input)
rusty1s's avatar
rusty1s committed
25
26
27
28
29
30
31
32


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of :func:`scatter_mul_`.

    A fresh output is obtained from ``gen_output`` and handed to
    :func:`scatter_mul_`.

    Args:
        index: Index argument, forwarded to both helpers.
        input: Source values, forwarded to both helpers.
        dim: Dimension, forwarded to both helpers. Defaults to ``0``.
        max_index: Forwarded to ``gen_output``. Defaults to ``None``.
        fill_value: Forwarded to ``gen_output``. Defaults to ``1``.

    Returns:
        The result of :func:`scatter_mul_` on the generated output.
    """
    return scatter_mul_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_div_(output, index, input, dim=0):
    """Apply the ``'div'`` scatter operation to ``output``.

    Thin wrapper that forwards to :func:`scatter` with the operation
    name ``'div'``.

    Args:
        output: Target forwarded to :func:`scatter`. Presumably updated
            in place, as the trailing underscore suggests -- confirm
            against :func:`scatter`.
        index: Index argument forwarded to :func:`scatter`.
        input: Source values forwarded to :func:`scatter`.
        dim: Dimension forwarded to :func:`scatter`. Defaults to ``0``.

    Returns:
        Whatever :func:`scatter` returns for this call.
    """
    return scatter('div', dim, output, index, input)
rusty1s's avatar
rusty1s committed
34
35
36
37


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of :func:`scatter_div_`.

    A fresh output is obtained from ``gen_output`` and handed to
    :func:`scatter_div_`.

    Args:
        index: Index argument, forwarded to both helpers.
        input: Source values, forwarded to both helpers.
        dim: Dimension, forwarded to both helpers. Defaults to ``0``.
        max_index: Forwarded to ``gen_output``. Defaults to ``None``.
        fill_value: Forwarded to ``gen_output``. Defaults to ``1``.

    Returns:
        The result of :func:`scatter_div_` on the generated output.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # Return the result like every sibling wrapper does; without the
    # ``return`` callers would always receive ``None``.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Apply the ``'mean'`` scatter operation to ``output`` and normalise it.

    Runs :func:`scatter` with the operation name ``'mean'`` and an extra
    count tensor, then divides ``output`` in place by that count.

    Args:
        output: Target forwarded to :func:`scatter`; divided in place by
            the count tensor afterwards.
        index: Index argument forwarded to :func:`scatter`.
        input: Source values forwarded to :func:`scatter`.
        dim: Dimension forwarded to :func:`scatter`. Defaults to ``0``.

    Returns:
        ``output`` after the in-place division.
    """
    # Companion tensor built from ``output`` with ``output.size()`` and a
    # zero fill; presumably receives per-slot element counts from
    # ``scatter`` -- confirm against its implementation.
    output_count = gen_filled_tensor(output, output.size(), fill_value=0)
    scatter('mean', dim, output, index, input, output_count)
    # Replace zero counts with one so the division below cannot divide
    # by zero.
    output_count[output_count == 0] = 1
    output /= output_count
    return output


rusty1s's avatar
cleaner  
rusty1s committed
49
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_mean_`.

    A fresh output is obtained from ``gen_output`` and handed to
    :func:`scatter_mean_`.

    Args:
        index: Index argument, forwarded to both helpers.
        input: Source values, forwarded to both helpers.
        dim: Dimension, forwarded to both helpers. Defaults to ``0``.
        max_index: Forwarded to ``gen_output``. Defaults to ``None``.
        fill_value: Forwarded to ``gen_output``. Defaults to ``0``.

    Returns:
        The result of :func:`scatter_mean_` on the generated output.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mean_(output, index, input, dim)
rusty1s's avatar
rusty1s committed
52
53


rusty1s's avatar
rusty1s committed
54
def scatter_max_(output, index, input, dim=0):
    """Apply the ``'max'`` scatter operation to ``output``.

    Forwards to :func:`scatter` with the operation name ``'max'`` and an
    additional index tensor.

    Args:
        output: Target forwarded to :func:`scatter`. Presumably updated
            in place, as the trailing underscore suggests -- confirm
            against :func:`scatter`.
        index: Index argument forwarded to :func:`scatter`; also passed
            to ``gen_filled_tensor`` when building the extra tensor.
        input: Source values forwarded to :func:`scatter`.
        dim: Dimension forwarded to :func:`scatter`. Defaults to ``0``.

    Returns:
        Whatever :func:`scatter` returns for this call.
    """
    # Extra tensor built from ``index`` with ``output.size()`` and a
    # ``-1`` fill; presumably records the position of each maximum, with
    # ``-1`` meaning "none" -- confirm against ``scatter``.
    output_index = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('max', dim, output, index, input, output_index)
rusty1s's avatar
rusty1s committed
57
58
59
60
61
62
63
64


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_max_`.

    A fresh output is obtained from ``gen_output`` and handed to
    :func:`scatter_max_`.

    Args:
        index: Index argument, forwarded to both helpers.
        input: Source values, forwarded to both helpers.
        dim: Dimension, forwarded to both helpers. Defaults to ``0``.
        max_index: Forwarded to ``gen_output``. Defaults to ``None``.
        fill_value: Forwarded to ``gen_output``. Defaults to ``0``.

    Returns:
        The result of :func:`scatter_max_` on the generated output.
    """
    return scatter_max_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_min_(output, index, input, dim=0):
    """Apply the ``'min'`` scatter operation to ``output``.

    Forwards to :func:`scatter` with the operation name ``'min'`` and an
    additional index tensor.

    Args:
        output: Target forwarded to :func:`scatter`. Presumably updated
            in place, as the trailing underscore suggests -- confirm
            against :func:`scatter`.
        index: Index argument forwarded to :func:`scatter`; also passed
            to ``gen_filled_tensor`` when building the extra tensor.
        input: Source values forwarded to :func:`scatter`.
        dim: Dimension forwarded to :func:`scatter`. Defaults to ``0``.

    Returns:
        Whatever :func:`scatter` returns for this call.
    """
    # Extra tensor built from ``index`` with ``output.size()`` and a
    # ``-1`` fill; presumably records the position of each minimum, with
    # ``-1`` meaning "none" -- confirm against ``scatter``.
    output_index = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('min', dim, output, index, input, output_index)
rusty1s's avatar
rusty1s committed
67
68
69
70
71
72
73


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of :func:`scatter_min_`.

    A fresh output is obtained from ``gen_output`` and handed to
    :func:`scatter_min_`.

    Args:
        index: Index argument, forwarded to both helpers.
        input: Source values, forwarded to both helpers.
        dim: Dimension, forwarded to both helpers. Defaults to ``0``.
        max_index: Forwarded to ``gen_output``. Defaults to ``None``.
        fill_value: Forwarded to ``gen_output``. Defaults to ``0``.

    Returns:
        The result of :func:`scatter_min_` on the generated output.
    """
    return scatter_min_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


rusty1s's avatar
rusty1s committed
74
75
# Public API: each operation is exported as an in-place variant (trailing
# underscore) together with its output-allocating counterpart.
__all__ = [
    'scatter_add_', 'scatter_add',
    'scatter_sub_', 'scatter_sub',
    'scatter_mul_', 'scatter_mul',
    'scatter_div_', 'scatter_div',
    'scatter_mean_', 'scatter_mean',
    'scatter_max_', 'scatter_max',
    'scatter_min_', 'scatter_min',
]