__init__.py 3.04 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch
from torch.autograd import Variable

rusty1s's avatar
rusty1s committed
4
5
6
7
8
from .scatter import scatter
from .utils import gen_output


def scatter_add_(output, index, input, dim=0):
    """Run the ``'add'`` scatter operation on ``output`` and return it.

    Delegates to ``scatter('add', dim, output, index, input)``. The return
    value of that call is ignored, so ``scatter`` is presumably expected to
    update ``output`` in place (its implementation lives in ``.scatter`` and
    is not visible from this module).

    Args:
        output: Object handed to ``scatter`` as the destination; the very
            same object is returned.
        index: Forwarded unchanged to ``scatter``.
        input: Forwarded unchanged to ``scatter``.
        dim: Forwarded unchanged to ``scatter``. Defaults to ``0``.

    Returns:
        The ``output`` argument.
    """
    scatter('add', dim, output, index, input)
    return output
rusty1s's avatar
rusty1s committed
11
12
13
14
15
16
17
18


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of ``scatter_add_``.

    The destination is obtained from
    ``gen_output(index, input, dim, max_index, fill_value)`` and then passed
    to ``scatter_add_`` together with ``index``, ``input`` and ``dim``.

    Returns:
        Whatever ``scatter_add_`` returns for that destination.
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_add_(out, index, input, dim)


def scatter_sub_(output, index, input, dim=0):
    """Run the ``'sub'`` scatter operation on ``output`` and return it.

    Delegates to ``scatter('sub', dim, output, index, input)``. The return
    value of that call is ignored, so ``scatter`` is presumably expected to
    update ``output`` in place (its implementation lives in ``.scatter`` and
    is not visible from this module).

    Args:
        output: Object handed to ``scatter`` as the destination; the very
            same object is returned.
        index: Forwarded unchanged to ``scatter``.
        input: Forwarded unchanged to ``scatter``.
        dim: Forwarded unchanged to ``scatter``. Defaults to ``0``.

    Returns:
        The ``output`` argument.
    """
    scatter('sub', dim, output, index, input)
    return output
rusty1s's avatar
rusty1s committed
21
22
23
24
25
26
27
28


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of ``scatter_sub_``.

    The destination is obtained from
    ``gen_output(index, input, dim, max_index, fill_value)`` and then passed
    to ``scatter_sub_`` together with ``index``, ``input`` and ``dim``.

    Returns:
        Whatever ``scatter_sub_`` returns for that destination.
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_sub_(out, index, input, dim)


def scatter_mul_(output, index, input, dim=0):
    """Run the ``'mul'`` scatter operation on ``output`` and return it.

    Delegates to ``scatter('mul', dim, output, index, input)``. The return
    value of that call is ignored, so ``scatter`` is presumably expected to
    update ``output`` in place (its implementation lives in ``.scatter`` and
    is not visible from this module).

    Args:
        output: Object handed to ``scatter`` as the destination; the very
            same object is returned.
        index: Forwarded unchanged to ``scatter``.
        input: Forwarded unchanged to ``scatter``.
        dim: Forwarded unchanged to ``scatter``. Defaults to ``0``.

    Returns:
        The ``output`` argument.
    """
    scatter('mul', dim, output, index, input)
    return output
rusty1s's avatar
rusty1s committed
31
32
33
34
35
36
37
38


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of ``scatter_mul_``.

    The destination is obtained from
    ``gen_output(index, input, dim, max_index, fill_value)`` and then passed
    to ``scatter_mul_`` together with ``index``, ``input`` and ``dim``.
    Note that ``fill_value`` defaults to ``1`` here.

    Returns:
        Whatever ``scatter_mul_`` returns for that destination.
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mul_(out, index, input, dim)


def scatter_div_(output, index, input, dim=0):
    """Run the ``'div'`` scatter operation on ``output`` and return it.

    Delegates to ``scatter('div', dim, output, index, input)``. The return
    value of that call is ignored, so ``scatter`` is presumably expected to
    update ``output`` in place (its implementation lives in ``.scatter`` and
    is not visible from this module).

    Args:
        output: Object handed to ``scatter`` as the destination; the very
            same object is returned.
        index: Forwarded unchanged to ``scatter``.
        input: Forwarded unchanged to ``scatter``.
        dim: Forwarded unchanged to ``scatter``. Defaults to ``0``.

    Returns:
        The ``output`` argument.
    """
    scatter('div', dim, output, index, input)
    return output
rusty1s's avatar
rusty1s committed
41
42
43
44


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place counterpart of ``scatter_div_``.

    The destination is obtained from
    ``gen_output(index, input, dim, max_index, fill_value)`` and then passed
    to ``scatter_div_`` together with ``index``, ``input`` and ``dim``.
    Note that ``fill_value`` defaults to ``1`` here.

    Returns:
        Whatever ``scatter_div_`` returns for that destination.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # The result must be returned explicitly; without the ``return`` this
    # wrapper yields ``None``, unlike every sibling out-of-place wrapper.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Run the ``'mean'`` scatter operation on ``output`` and return it.

    A zero-filled counter with the size of ``output`` is created and passed
    to ``scatter`` as an extra argument. Afterwards ``output`` is divided in
    place by that counter. What exactly ``scatter`` writes into ``output``
    and the counter is defined in ``.scatter`` and not visible here;
    presumably sums and per-slot element counts (to be confirmed there).

    Args:
        output: Destination handed to ``scatter``; divided in place and
            returned.
        index: Forwarded unchanged to ``scatter``.
        input: Forwarded unchanged to ``scatter``; also decides whether the
            counter is built as a raw tensor or wrapped in a ``Variable``.
        dim: Forwarded unchanged to ``scatter``. Defaults to ``0``.

    Returns:
        The ``output`` argument after the in-place division.
    """
    # Build the counter the same way ``input`` is represented: a plain
    # tensor when ``input`` is one, otherwise a ``Variable`` around new data.
    if torch.is_tensor(input):
        output_count = output.new(output.size()).fill_(0)
    else:
        output_count = Variable(output.data.new(output.size()).fill_(0))

    scatter('mean', dim, output, index, input, output_count)

    # Replace zero counts by one so the division below never divides by zero.
    output_count[output_count == 0] = 1
    output /= output_count
    return output


rusty1s's avatar
cleaner  
rusty1s committed
59
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of ``scatter_mean_``.

    The destination is obtained from
    ``gen_output(index, input, dim, max_index, fill_value)`` and then passed
    to ``scatter_mean_`` together with ``index``, ``input`` and ``dim``.

    Returns:
        Whatever ``scatter_mean_`` returns for that destination.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mean_(output, index, input, dim)
rusty1s's avatar
rusty1s committed
62
63


rusty1s's avatar
rusty1s committed
64
def scatter_max_(output, index, input, dim=0):
    """Run the ``'max'`` scatter operation on ``output``.

    A companion buffer of ``index``'s type, sized like ``output`` and filled
    with ``-1``, is created and passed to ``scatter`` as an extra argument.
    What ``scatter`` stores in it is defined in ``.scatter`` and not visible
    here; presumably the position of each selected element, with ``-1``
    marking slots that were never written (to be confirmed there).

    Args:
        output: Destination handed to ``scatter``; returned as the first
            element of the result.
        index: Forwarded unchanged to ``scatter``; also supplies the type of
            the companion buffer.
        input: Forwarded unchanged to ``scatter``; also decides whether the
            companion buffer is a raw tensor or wrapped in a ``Variable``.
        dim: Forwarded unchanged to ``scatter``. Defaults to ``0``.

    Returns:
        A tuple ``(output, output_index)``.
    """
    # Build the companion buffer the same way ``input`` is represented: a
    # plain tensor when ``input`` is one, otherwise a ``Variable``.
    if torch.is_tensor(input):
        output_index = index.new(output.size()).fill_(-1)
    else:
        output_index = Variable(index.data.new(output.size()).fill_(-1))

    scatter('max', dim, output, index, input, output_index)
    return output, output_index


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of ``scatter_max_``.

    The destination is obtained from
    ``gen_output(index, input, dim, max_index, fill_value)`` and then passed
    to ``scatter_max_`` together with ``index``, ``input`` and ``dim``.

    Returns:
        Whatever ``scatter_max_`` returns for that destination (a
        two-element tuple).
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_max_(out, index, input, dim)


def scatter_min_(output, index, input, dim=0):
    """Run the ``'min'`` scatter operation on ``output``.

    A companion buffer of ``index``'s type, sized like ``output`` and filled
    with ``-1``, is created and passed to ``scatter`` as an extra argument.
    What ``scatter`` stores in it is defined in ``.scatter`` and not visible
    here; presumably the position of each selected element, with ``-1``
    marking slots that were never written (to be confirmed there).

    Args:
        output: Destination handed to ``scatter``; returned as the first
            element of the result.
        index: Forwarded unchanged to ``scatter``; also supplies the type of
            the companion buffer.
        input: Forwarded unchanged to ``scatter``; also decides whether the
            companion buffer is a raw tensor or wrapped in a ``Variable``.
        dim: Forwarded unchanged to ``scatter``. Defaults to ``0``.

    Returns:
        A tuple ``(output, output_index)``.
    """
    # Build the companion buffer the same way ``input`` is represented: a
    # plain tensor when ``input`` is one, otherwise a ``Variable``.
    if torch.is_tensor(input):
        output_index = index.new(output.size()).fill_(-1)
    else:
        output_index = Variable(index.data.new(output.size()).fill_(-1))

    scatter('min', dim, output, index, input, output_index)
    return output, output_index


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place counterpart of ``scatter_min_``.

    The destination is obtained from
    ``gen_output(index, input, dim, max_index, fill_value)`` and then passed
    to ``scatter_min_`` together with ``index``, ``input`` and ``dim``.

    Returns:
        Whatever ``scatter_min_`` returns for that destination (a
        two-element tuple).
    """
    out = gen_output(index, input, dim, max_index, fill_value)
    return scatter_min_(out, index, input, dim)


rusty1s's avatar
rusty1s committed
92
93
# Public API of the package: each operation is exported as an in-place
# variant (trailing underscore) and an out-of-place variant.
__all__ = [
    'scatter_add_',
    'scatter_add',
    'scatter_sub_',
    'scatter_sub',
    'scatter_mul_',
    'scatter_mul',
    'scatter_div_',
    'scatter_div',
    'scatter_mean_',
    'scatter_mean',
    'scatter_max_',
    'scatter_max',
    'scatter_min_',
    'scatter_min',
]