__init__.py 2.94 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch
from torch.autograd import Variable

rusty1s's avatar
rusty1s committed
4
5
6
7
8
from .scatter import scatter
from .utils import gen_output


def scatter_add_(output, index, input, dim=0):
    """Dispatch the ``'add'`` scatter operation onto ``output``.

    Thin wrapper that forwards its arguments to ``scatter`` with the
    operation name ``'add'``.

    Args:
        output: Destination handed to ``scatter`` (presumably updated in
            place, going by the trailing-underscore name -- confirm
            against ``.scatter``).
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``.
        dim: Dimension forwarded to ``scatter``. Defaults to ``0``.

    Returns:
        Whatever ``scatter`` returns for the ``'add'`` operation.
    """
    result = scatter('add', dim, output, index, input)
    return result
rusty1s's avatar
rusty1s committed
10
11
12
13
14
15
16
17


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    """Run the ``'add'`` scatter into a freshly generated output.

    The output is built by ``gen_output`` and then handed to
    ``scatter_add_`` together with ``index``, ``input`` and ``dim``.

    Args:
        index: Index argument, passed to both ``gen_output`` and
            ``scatter_add_``.
        input: Source values, passed to both ``gen_output`` and
            ``scatter_add_``.
        dim: Dimension to operate along. Defaults to ``0``.
        max_index: Forwarded to ``gen_output`` (presumably bounds the
            output size along ``dim`` -- confirm in ``.utils``).
        fill_value: Forwarded to ``gen_output`` (presumably the initial
            content of the output). Defaults to ``0``.

    Returns:
        The result of ``scatter_add_`` on the generated output.
    """
    return scatter_add_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_sub_(output, index, input, dim=0):
    """Dispatch the ``'sub'`` scatter operation onto ``output``.

    Thin wrapper that forwards its arguments to ``scatter`` with the
    operation name ``'sub'``.

    Args:
        output: Destination handed to ``scatter`` (presumably updated in
            place, going by the trailing-underscore name -- confirm
            against ``.scatter``).
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``.
        dim: Dimension forwarded to ``scatter``. Defaults to ``0``.

    Returns:
        Whatever ``scatter`` returns for the ``'sub'`` operation.
    """
    result = scatter('sub', dim, output, index, input)
    return result
rusty1s's avatar
rusty1s committed
19
20
21
22
23
24
25
26


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Run the ``'sub'`` scatter into a freshly generated output.

    The output is built by ``gen_output`` and then handed to
    ``scatter_sub_`` together with ``index``, ``input`` and ``dim``.

    Args:
        index: Index argument, passed to both ``gen_output`` and
            ``scatter_sub_``.
        input: Source values, passed to both ``gen_output`` and
            ``scatter_sub_``.
        dim: Dimension to operate along. Defaults to ``0``.
        max_index: Forwarded to ``gen_output`` (presumably bounds the
            output size along ``dim`` -- confirm in ``.utils``).
        fill_value: Forwarded to ``gen_output`` (presumably the initial
            content of the output). Defaults to ``0``.

    Returns:
        The result of ``scatter_sub_`` on the generated output.
    """
    return scatter_sub_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_mul_(output, index, input, dim=0):
    """Dispatch the ``'mul'`` scatter operation onto ``output``.

    Thin wrapper that forwards its arguments to ``scatter`` with the
    operation name ``'mul'``.

    Args:
        output: Destination handed to ``scatter`` (presumably updated in
            place, going by the trailing-underscore name -- confirm
            against ``.scatter``).
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``.
        dim: Dimension forwarded to ``scatter``. Defaults to ``0``.

    Returns:
        Whatever ``scatter`` returns for the ``'mul'`` operation.
    """
    result = scatter('mul', dim, output, index, input)
    return result
rusty1s's avatar
rusty1s committed
28
29
30
31
32
33
34
35


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Run the ``'mul'`` scatter into a freshly generated output.

    The output is built by ``gen_output`` and then handed to
    ``scatter_mul_`` together with ``index``, ``input`` and ``dim``.

    Args:
        index: Index argument, passed to both ``gen_output`` and
            ``scatter_mul_``.
        input: Source values, passed to both ``gen_output`` and
            ``scatter_mul_``.
        dim: Dimension to operate along. Defaults to ``0``.
        max_index: Forwarded to ``gen_output`` (presumably bounds the
            output size along ``dim`` -- confirm in ``.utils``).
        fill_value: Forwarded to ``gen_output`` (presumably the initial
            content of the output). Defaults to ``1``.

    Returns:
        The result of ``scatter_mul_`` on the generated output.
    """
    return scatter_mul_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_div_(output, index, input, dim=0):
    """Dispatch the ``'div'`` scatter operation onto ``output``.

    Thin wrapper that forwards its arguments to ``scatter`` with the
    operation name ``'div'``.

    Args:
        output: Destination handed to ``scatter`` (presumably updated in
            place, going by the trailing-underscore name -- confirm
            against ``.scatter``).
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``.
        dim: Dimension forwarded to ``scatter``. Defaults to ``0``.

    Returns:
        Whatever ``scatter`` returns for the ``'div'`` operation.
    """
    result = scatter('div', dim, output, index, input)
    return result
rusty1s's avatar
rusty1s committed
37
38
39
40


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Run the ``'div'`` scatter into a freshly generated output.

    The output is built by ``gen_output`` and then handed to
    ``scatter_div_`` together with ``index``, ``input`` and ``dim``.

    Args:
        index: Index argument, passed to both ``gen_output`` and
            ``scatter_div_``.
        input: Source values, passed to both ``gen_output`` and
            ``scatter_div_``.
        dim: Dimension to operate along. Defaults to ``0``.
        max_index: Forwarded to ``gen_output`` (presumably bounds the
            output size along ``dim`` -- confirm in ``.utils``).
        fill_value: Forwarded to ``gen_output`` (presumably the initial
            content of the output). Defaults to ``1``.

    Returns:
        The result of ``scatter_div_`` on the generated output.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # The result must be returned: the generated output is local to this
    # function, so without the return callers only ever receive ``None``.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Dispatch the ``'mean'`` scatter operation and normalise ``output``.

    A zero-filled count buffer with the size of ``output`` is passed to
    ``scatter`` as an extra argument. Afterwards ``output`` is divided in
    place by that buffer, with zero counts replaced by one so no entry is
    divided by zero.

    Args:
        output: Destination; relied upon to be mutated by ``scatter`` and
            then divided in place here.
        index: Index argument forwarded to ``scatter``.
        input: Source values forwarded to ``scatter``. Whether it is a
            plain tensor decides if the count buffer is wrapped in a
            ``Variable``.
        dim: Dimension forwarded to ``scatter``. Defaults to ``0``.

    Returns:
        ``output`` after the in-place division.
    """
    plain_tensor = torch.is_tensor(input)
    # Allocate from the raw tensor so the buffer matches output's type.
    template = output if plain_tensor else output.data
    counts = template.new(output.size()).fill_(0)
    if not plain_tensor:
        counts = Variable(counts)

    # The return value is ignored: scatter is expected to fill both
    # ``output`` and ``counts`` in place (presumably sums and per-slot
    # element counts -- confirm against ``.scatter``).
    scatter('mean', dim, output, index, input, counts)

    # Untouched slots keep a zero count; use one to avoid dividing by zero.
    counts[counts == 0] = 1
    output /= counts
    return output


rusty1s's avatar
cleaner  
rusty1s committed
55
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
    """Run the ``'mean'`` scatter into a freshly generated output.

    The output is built by ``gen_output`` and then handed to
    ``scatter_mean_`` together with ``index``, ``input`` and ``dim``.

    Args:
        index: Index argument, passed to both ``gen_output`` and
            ``scatter_mean_``.
        input: Source values, passed to both ``gen_output`` and
            ``scatter_mean_``.
        dim: Dimension to operate along. Defaults to ``0``.
        max_index: Forwarded to ``gen_output`` (presumably bounds the
            output size along ``dim`` -- confirm in ``.utils``).
        fill_value: Forwarded to ``gen_output`` (presumably the initial
            content of the output). Defaults to ``0``.

    Returns:
        The result of ``scatter_mean_`` on the generated output.
    """
    return scatter_mean_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)
rusty1s's avatar
rusty1s committed
58
59


rusty1s's avatar
rusty1s committed
60
def scatter_max_(output, index, input, dim=0):
    """Dispatch the ``'max'`` scatter operation onto ``output``.

    A buffer with the type of ``index`` and the size of ``output``,
    filled with ``-1``, is passed to ``scatter`` as an extra argument
    (presumably to record which input position produced each maximum --
    confirm against ``.scatter``).

    Args:
        output: Destination handed to ``scatter``.
        index: Index argument forwarded to ``scatter``; also the
            template for the extra buffer's type.
        input: Source values forwarded to ``scatter``. Whether it is a
            plain tensor decides if the buffer is wrapped in a
            ``Variable``.
        dim: Dimension forwarded to ``scatter``. Defaults to ``0``.

    Returns:
        Whatever ``scatter`` returns for the ``'max'`` operation.
    """
    plain_tensor = torch.is_tensor(input)
    # Allocate from the raw tensor so the buffer shares index's type.
    template = index if plain_tensor else index.data
    arg_buffer = template.new(output.size()).fill_(-1)
    if not plain_tensor:
        arg_buffer = Variable(arg_buffer)
    return scatter('max', dim, output, index, input, arg_buffer)
rusty1s's avatar
rusty1s committed
66
67
68
69
70
71
72
73


def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
    """Run the ``'max'`` scatter into a freshly generated output.

    The output is built by ``gen_output`` and then handed to
    ``scatter_max_`` together with ``index``, ``input`` and ``dim``.

    Args:
        index: Index argument, passed to both ``gen_output`` and
            ``scatter_max_``.
        input: Source values, passed to both ``gen_output`` and
            ``scatter_max_``.
        dim: Dimension to operate along. Defaults to ``0``.
        max_index: Forwarded to ``gen_output`` (presumably bounds the
            output size along ``dim`` -- confirm in ``.utils``).
        fill_value: Forwarded to ``gen_output`` (presumably the initial
            content of the output). Defaults to ``0``.

    Returns:
        The result of ``scatter_max_`` on the generated output.
    """
    return scatter_max_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


def scatter_min_(output, index, input, dim=0):
    """Dispatch the ``'min'`` scatter operation onto ``output``.

    A buffer with the type of ``index`` and the size of ``output``,
    filled with ``-1``, is passed to ``scatter`` as an extra argument
    (presumably to record which input position produced each minimum --
    confirm against ``.scatter``).

    Args:
        output: Destination handed to ``scatter``.
        index: Index argument forwarded to ``scatter``; also the
            template for the extra buffer's type.
        input: Source values forwarded to ``scatter``. Whether it is a
            plain tensor decides if the buffer is wrapped in a
            ``Variable``.
        dim: Dimension forwarded to ``scatter``. Defaults to ``0``.

    Returns:
        Whatever ``scatter`` returns for the ``'min'`` operation.
    """
    plain_tensor = torch.is_tensor(input)
    # Allocate from the raw tensor so the buffer shares index's type.
    template = index if plain_tensor else index.data
    arg_buffer = template.new(output.size()).fill_(-1)
    if not plain_tensor:
        arg_buffer = Variable(arg_buffer)
    return scatter('min', dim, output, index, input, arg_buffer)
rusty1s's avatar
rusty1s committed
79
80
81
82
83
84
85


def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
    """Run the ``'min'`` scatter into a freshly generated output.

    The output is built by ``gen_output`` and then handed to
    ``scatter_min_`` together with ``index``, ``input`` and ``dim``.

    Args:
        index: Index argument, passed to both ``gen_output`` and
            ``scatter_min_``.
        input: Source values, passed to both ``gen_output`` and
            ``scatter_min_``.
        dim: Dimension to operate along. Defaults to ``0``.
        max_index: Forwarded to ``gen_output`` (presumably bounds the
            output size along ``dim`` -- confirm in ``.utils``).
        fill_value: Forwarded to ``gen_output`` (presumably the initial
            content of the output). Defaults to ``0``.

    Returns:
        The result of ``scatter_min_`` on the generated output.
    """
    return scatter_min_(
        gen_output(index, input, dim, max_index, fill_value),
        index, input, dim)


rusty1s's avatar
rusty1s committed
86
87
# Public API: each operation is exported as an in-place variant (trailing
# underscore) followed by the variant that generates its own output.
__all__ = [
    'scatter_add_', 'scatter_add',
    'scatter_sub_', 'scatter_sub',
    'scatter_mul_', 'scatter_mul',
    'scatter_div_', 'scatter_div',
    'scatter_mean_', 'scatter_mean',
    'scatter_max_', 'scatter_max',
    'scatter_min_', 'scatter_min',
]