scatter.py 1.21 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
import torch
from torch.autograd import Function

from .._ext import ffi


rusty1s's avatar
rusty1s committed
7
8
def _scatter(name, dim, *data):
    """Invoke the compiled ``scatter_<name>_<type>`` kernel from ``ffi``.

    The ``<type>`` suffix is the class name of ``data[0]`` with every
    occurrence of ``'Tensor'`` removed (e.g. ``FloatTensor`` -> ``Float``).
    ``dim`` followed by all of ``data`` is forwarded positionally to the
    kernel; the kernel's return value is discarded, so this returns ``None``.
    """
    first = data[0]
    suffix = type(first).__name__.replace('Tensor', '')
    symbol = 'scatter_{}_{}'.format(name, suffix)
    kernel = getattr(ffi, symbol)
    kernel(dim, *data)
rusty1s's avatar
rusty1s committed
11
12
13
14
15
16
17
18


class _Scatter(Function):
    def __init__(self, name, dim):
        super(_Scatter, self).__init__()
        self.dim = dim
        self.name = name

rusty1s's avatar
rusty1s committed
19
    def forward(self, *data):
rusty1s's avatar
rusty1s committed
20
21
        assert not self.needs_input_grad[1], 'Can\'t differentiate the index'

rusty1s's avatar
rusty1s committed
22
23
        self.mark_dirty(data[0])
        self.save_for_backward(data[1])
rusty1s's avatar
rusty1s committed
24

rusty1s's avatar
rusty1s committed
25
26
        _scatter(self.name, self.dim, *data)
        return data[0]
rusty1s's avatar
rusty1s committed
27
28
29
30
31
32
33
34
35
36

    def backward(self, grad):
        index, = self.saved_variables
        grad_output = grad_input = None

        if self.needs_input_grad[0]:
            grad_output = grad
        if self.needs_input_grad[2]:
            grad_input = grad.gather(self.dim, index.data)

rusty1s's avatar
rusty1s committed
37
38
39
40
        if len(grad) == 3:
            return grad_output, None, grad_input
        else:
            return grad_output, None, grad_input, None
rusty1s's avatar
rusty1s committed
41
42


rusty1s's avatar
rusty1s committed
43
44
45
def scatter(name, dim, *data):
    """Apply ``scatter_<name>`` along ``dim``, updating ``data[0]`` in place.

    Plain tensors (``torch.is_tensor`` is true for ``data[0]``) are sent
    directly to the compiled kernel. Anything else (presumably autograd
    ``Variable`` objects; TODO confirm) goes through the ``_Scatter`` autograd
    Function so that gradients are tracked.

    Returns:
        The updated ``data[0]`` in both code paths.
    """
    if torch.is_tensor(data[0]):
        # ``_scatter`` writes into ``data[0]`` and itself returns ``None``;
        # return the updated tensor so this branch matches the autograd one.
        _scatter(name, dim, *data)
        return data[0]
    return _Scatter(name, dim)(*data)