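"""Autograd plumbing shared by the scatter operations.

`Scatter` wraps an FFI scatter call in an (old-style) autograd `Function`;
the module-level `scatter` dispatcher calls the FFI directly for plain
tensors and goes through the given `Function` subclass otherwise.
"""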
import torch
from torch.autograd import Function

from .ffi import scatter as ffi_scatter


class Scatter(Function):
    def __init__(self, name, dim):
        super(Scatter, self).__init__()
        self.name = name
        self.dim = dim

    def save_for_backward_step(self, *data):
        raise NotImplementedError

    def forward(self, *data):
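        # `data` is expected to be `(output, index, input, ...)`: `output` is
        # scattered into in-place, `index` is non-differentiable, and
        # gradients are only produced for `output` and `input`.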
        assert not self.needs_input_grad[1], 'Can\'t differentiate the index'

        self.mark_dirty(data[0])  # Mark output as dirty.
        self.len = len(data)  # Save number of arguments for backward step.

        output = ffi_scatter(self.name, self.dim, *data)
        self.save_for_backward_step(*data)

        return output

    def backward(self, *data):  # pragma: no cover
        grad_output = grad_input = None

        if self.needs_input_grad[0]:
            grad_output = data[0]

        # Call grad computation of `input` for the specific scatter operation.
        if self.needs_input_grad[2]:
            grad_input = self.backward_step(data[0], *self.saved_variables)

        # Return and fill with empty grads for non-differentiable arguments.
        return (grad_output, None, grad_input) + (None, ) * (self.len - 3)

    def backward_step(self, *data):
        raise NotImplementedError


def scatter(Clx, name, dim, *data):
    if torch.is_tensor(data[0]):
        return ffi_scatter(name, dim, *data)
    else:
        return Clx(dim)(*data)
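

# A minimal, hypothetical sketch (not part of this module) of how a concrete
# subclass is meant to plug into `Scatter`: fix `name`, save the index in
# `save_for_backward_step`, and gather the incoming gradient along `dim` in
# `backward_step`.  `ScatterAdd` and the argument order `(output, index,
# input)` are illustrative assumptions only.
#
# class ScatterAdd(Scatter):
#     def __init__(self, dim):
#         super(ScatterAdd, self).__init__('add', dim)
#
#     def save_for_backward_step(self, *data):
#         self.save_for_backward(data[1])  # Only the index is needed.
#
#     def backward_step(self, grad, index):
#         # Route each output gradient back to the input position it was
#         # scattered from, mirroring the gather in the commented-out
#         # helpers below.
#         return grad.gather(self.dim, index.data)
#
# scatter(ScatterAdd, 'add', dim, output, index, input)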


# def index_backward(dim, index, grad, arg):  # pragma: no cover
#     typename = type(grad).__name__.replace('Tensor', '')
#     cuda = 'cuda_' if grad.is_cuda else ''
#     func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
#     output = grad.new(index.size()).fill_(0)
#     func(dim, output, index, grad, arg)
#     return output

# def _scatter_backward(name, dim, saved, *data):
#     # saved = (index, ), (index, arg) or (index, count)

#     print(name)
#     print(len(data))
#     print(len(saved))
#     print(saved[1].size())
#     # data = (grad, )
#     # index, = saved
#     if has_arg(name):
#         return index_backward(dim, saved[0].data, data[0], saved[1].data)

#     if has_count(name):
#         return (data[0] / saved[1]).gather(dim, saved[0].data)
#     # Different grad computation of `input` if `scatter_max` or
#     # `scatter_min` was used.
#     # if self.needs_input_grad[2] and not has_arg(self.name):
#     #     index, = self.saved_variables
#     #     grad_input = data[0].gather(self.dim, index.data)

#     # if self.needs_input_grad[2] and has_arg(self.name):
#     #     index, arg = self.saved_variables
#     #     data = (index.data, data[0], arg.data)
#     grad_input = index_backward(self.dim, *data)