__init__.py 2 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch
from torch.autograd import Variable

rusty1s's avatar
rusty1s committed
4
5
6
7
8
from .scatter import scatter
from .utils import gen_output


def scatter_add_(output, index, input, dim=0):
    """Apply the ``'add'`` scatter reduction into ``output`` and return it.

    Forwards to the package's ``scatter`` helper with the op name ``'add'``.
    NOTE(review): the trailing underscore and the returned ``output`` suggest
    ``output`` is modified in place — confirm against ``.scatter``.

    Args:
        output: Destination object; the very same object is returned.
        index: Positions in ``output`` (along ``dim``) that ``input`` values
            are routed to — exact semantics live in ``.scatter``.
        input: Source values to scatter.
        dim: Dimension to scatter along. Defaults to 0.

    Returns:
        ``output`` itself.
    """
    op_name = 'add'
    scatter(op_name, dim, output, index, input)
    return output


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of ``scatter_add_``.

    A fresh destination is built by ``gen_output`` from ``index``, ``input``,
    ``dim``, ``max_index`` and ``fill_value``. Presumably it is sized from
    ``max_index``/``index`` and pre-filled with ``fill_value`` — confirm in
    ``.utils``. ``scatter_add_`` is then applied to it.

    Returns:
        Whatever ``scatter_add_`` returns for the new destination.
    """
    destination = gen_output(index, input, dim, max_index, fill_value)
    result = scatter_add_(destination, index, input, dim)
    return result


def scatter_sub_(output, index, input, dim=0):
    """Apply the ``'sub'`` scatter reduction into ``output`` and return it.

    Forwards to the package's ``scatter`` helper with the op name ``'sub'``.
    NOTE(review): presumably mutates ``output`` in place — confirm against
    ``.scatter``.

    Args:
        output: Destination object; the very same object is returned.
        index: Positions in ``output`` (along ``dim``) that ``input`` values
            are routed to — exact semantics live in ``.scatter``.
        input: Source values to scatter.
        dim: Dimension to scatter along. Defaults to 0.

    Returns:
        ``output`` itself.
    """
    op_name = 'sub'
    scatter(op_name, dim, output, index, input)
    return output


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    """Out-of-place variant of ``scatter_sub_``.

    A fresh destination is built by ``gen_output`` from ``index``, ``input``,
    ``dim``, ``max_index`` and ``fill_value``. Presumably it is pre-filled
    with ``fill_value`` — confirm in ``.utils``. ``scatter_sub_`` is then
    applied to it.

    Returns:
        Whatever ``scatter_sub_`` returns for the new destination.
    """
    destination = gen_output(index, input, dim, max_index, fill_value)
    result = scatter_sub_(destination, index, input, dim)
    return result


def scatter_mul_(output, index, input, dim=0):
    """Apply the ``'mul'`` scatter reduction into ``output`` and return it.

    Forwards to the package's ``scatter`` helper with the op name ``'mul'``.
    NOTE(review): presumably mutates ``output`` in place — confirm against
    ``.scatter``.

    Args:
        output: Destination object; the very same object is returned.
        index: Positions in ``output`` (along ``dim``) that ``input`` values
            are routed to — exact semantics live in ``.scatter``.
        input: Source values to scatter.
        dim: Dimension to scatter along. Defaults to 0.

    Returns:
        ``output`` itself.
    """
    op_name = 'mul'
    scatter(op_name, dim, output, index, input)
    return output


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of ``scatter_mul_``.

    A fresh destination is built by ``gen_output`` from ``index``, ``input``,
    ``dim``, ``max_index`` and ``fill_value``. The default seed of 1 is
    presumably the multiplicative identity — confirm in ``.utils``.
    ``scatter_mul_`` is then applied to it.

    Returns:
        Whatever ``scatter_mul_`` returns for the new destination.
    """
    destination = gen_output(index, input, dim, max_index, fill_value)
    result = scatter_mul_(destination, index, input, dim)
    return result


def scatter_div_(output, index, input, dim=0):
    """Apply the ``'div'`` scatter reduction into ``output`` and return it.

    Forwards to the package's ``scatter`` helper with the op name ``'div'``.
    NOTE(review): presumably mutates ``output`` in place — confirm against
    ``.scatter``.

    Args:
        output: Destination object; the very same object is returned.
        index: Positions in ``output`` (along ``dim``) that ``input`` values
            are routed to — exact semantics live in ``.scatter``.
        input: Source values to scatter.
        dim: Dimension to scatter along. Defaults to 0.

    Returns:
        ``output`` itself.
    """
    op_name = 'div'
    scatter(op_name, dim, output, index, input)
    return output


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of ``scatter_div_``.

    A fresh destination is built by ``gen_output`` from ``index``, ``input``,
    ``dim``, ``max_index`` and ``fill_value``. Presumably it is pre-filled
    with ``fill_value`` — confirm in ``.utils``. ``scatter_div_`` is then
    applied to it.

    Args:
        index: Positions the ``input`` values are routed to along ``dim``.
        input: Source values to scatter.
        dim: Dimension to scatter along. Defaults to 0.
        max_index: Forwarded to ``gen_output``.
        fill_value: Forwarded to ``gen_output`` as the initial value.
            Defaults to 1.

    Returns:
        The newly created destination, as returned by ``scatter_div_``.
    """
    output = gen_output(index, input, dim, max_index, fill_value)
    # The result must be returned, as in scatter_add/sub/mul/mean; previously
    # it was dropped and callers always received None.
    return scatter_div_(output, index, input, dim)


def scatter_mean_(output, index, input, dim=0):
    """Apply the ``'mean'`` scatter reduction into ``output`` and return it.

    A zero-initialised ``count`` buffer with ``output.size()`` is passed to
    ``scatter('mean', ...)`` as an extra argument.

    NOTE(review): presumably the kernel accumulates into ``output`` and
    records, in ``count``, how many entries land at each position — confirm
    in ``.scatter``.

    Afterwards every zero entry of ``count`` is replaced by 1, and ``output``
    is divided in place by ``count``. Positions with a count of zero are
    therefore divided by 1 instead of 0.

    The buffer is a plain tensor when ``input`` is a tensor. Otherwise it is
    a ``Variable`` wrapping a tensor built from ``output.data`` (legacy
    autograd API).

    Args:
        output: Destination object; the same object is divided and returned.
        index: Positions the ``input`` values are routed to along ``dim``.
        input: Source values (tensor or ``Variable``).
        dim: Dimension to scatter along. Defaults to 0.

    Returns:
        ``output`` after the in-place division.
    """
    if not torch.is_tensor(input):
        count = Variable(output.data.new(output.size()).fill_(0))
    else:
        count = output.new(output.size()).fill_(0)

    scatter('mean', dim, output, index, input, count)

    # Avoid dividing by zero at positions nothing was scattered to.
    count[count == 0] = 1
    output /= count
    return output


def scatter_mean(index, input, dim=0, max_index=None, fill_value=1):
    """Out-of-place variant of ``scatter_mean_``.

    A fresh destination is built by ``gen_output`` from ``index``, ``input``,
    ``dim``, ``max_index`` and ``fill_value``. ``scatter_mean_`` is then
    applied to it.

    NOTE(review): the default ``fill_value=1`` differs from
    ``scatter_add``/``scatter_sub`` (which use 0). If the 'mean' kernel
    accumulates into the pre-filled destination, every non-empty mean would
    include this seed (e.g. ``(1 + a + b) / 2``), so 0 looks intended.
    The default is left unchanged because changing it alters the public
    API — confirm.

    Returns:
        Whatever ``scatter_mean_`` returns for the new destination.
    """
    return scatter_mean_(
        gen_output(index, input, dim, max_index, fill_value),
        index,
        input,
        dim,
    )


# Public API: each reduction comes as an in-place variant (trailing
# underscore, writes into a caller-supplied output) and an out-of-place
# variant that allocates its own output.
__all__ = [
    'scatter_add_',
    'scatter_add',
    'scatter_sub_',
    'scatter_sub',
    'scatter_mul_',
    'scatter_mul',
    'scatter_div_',
    'scatter_div',
    'scatter_mean_',
    'scatter_mean',
]