test_add.py 1.55 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from nose.tools import assert_equal

import torch
rusty1s's avatar
rusty1s committed
4
from torch.autograd import Variable
rusty1s's avatar
rename  
rusty1s committed
5
from torch_scatter._ext import ffi
rusty1s's avatar
rusty1s committed
6
7


rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class ScatterAdd(torch.autograd.Function):
    """Autograd wrapper around the ``ffi.scatter_add_Float`` extension call.

    An instance is created with the dimension to operate along and is then
    applied as ``ScatterAdd(dim)(output, index, input)``.
    """

    def __init__(self, dim):
        """Store ``dim``; it is forwarded to the extension call and to
        ``gather`` in the backward pass."""
        super(ScatterAdd, self).__init__()
        self.dim = dim

    def forward(self, output, index, input):
        """Apply ``ffi.scatter_add_Float`` to ``output`` and return ``output``.

        ``index`` is saved for the backward pass. Requesting a gradient with
        respect to ``index`` fails the assertion below.
        """
        index_needs_grad = self.needs_input_grad[1]
        assert not index_needs_grad, "Can't differentiate the index"

        # `output` is handed back as the result after the extension call, so
        # it is registered with autograd as modified in place.
        self.mark_dirty(output)
        self.save_for_backward(index)

        ffi.scatter_add_Float(output, index, input, self.dim)
        return output

    def backward(self, grad):
        """Return the gradients for ``(output, index, input)``.

        The slot for ``index`` is always ``None``. The other two slots are
        only filled when autograd asked for them.
        """
        (index,) = self.saved_variables

        # The incoming gradient is passed through unchanged for `output`.
        wants_output_grad = self.needs_input_grad[0]
        grad_output = grad if wants_output_grad else None

        # For `input`, pick the gradient entries at the saved index positions.
        wants_input_grad = self.needs_input_grad[2]
        grad_input = (
            grad.gather(self.dim, index.data) if wants_input_grad else None
        )

        return grad_output, None, grad_input


rusty1s's avatar
rusty1s committed
34
def test_scatter_add():
    """Check ``ffi.scatter_add_Float`` directly, then gradients via ``ScatterAdd``.

    The first half calls the extension function on plain tensors and compares
    the result with a hand-computed expectation. The second half wraps the
    tensors in ``Variable`` objects, runs ``ScatterAdd`` inside a small graph
    and checks the gradients produced by ``backward``.
    """
    src = torch.FloatTensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
    index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
    output = src.new(2, 6).fill_(0)
    expected_output = [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]

    ffi.scatter_add_Float(output, index, src, 1)
    assert_equal(output.tolist(), expected_output)

    output = Variable(output)
    index = Variable(index)
    src = Variable(src, requires_grad=True)

    # Multiply first so that both arguments handed to ScatterAdd are
    # intermediate results of the graph rather than the leaf variables.
    a = src * 2
    b = output * 2
    ScatterAdd(1)(b, index, a)

    c = b.sum()
    c.backward()

    # c sums every entry of b, so the gradient reaching ScatterAdd is all
    # ones; gathering it at `index` keeps ones, and a = src * 2 doubles it.
    expected_grad = [[2, 2, 2, 2, 2], [2, 2, 2, 2, 2]]
    assert_equal(src.grad.data.tolist(), expected_grad)

    # `output` was wrapped without requires_grad, so no gradient is stored.
    assert output.grad is None