scatter.py 1.31 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
import torch
from torch.autograd import Function

from .._ext import ffi


rusty1s's avatar
rusty1s committed
7
8
def _scatter(name, dim, *data):
    """Call the compiled ``scatter_<name>_<type>`` kernel on raw tensors.

    The kernel is looked up on ``ffi`` by combining ``name`` with the class
    name of ``data[0]`` stripped of ``'Tensor'`` (e.g. ``FloatTensor`` ->
    ``Float``).

    Args:
        name: Name of the scatter operation; part of the kernel name.
        dim: Dimension to scatter along; forwarded unchanged to the kernel.
        *data: Arguments forwarded to the kernel as-is. ``data[0]`` is the
            output the kernel updates in place.

    Returns:
        ``data[0]``, so the raw-tensor path hands back the output just like
        the autograd path does (previously this returned ``None``).
    """
    # NOTE(review): only the class name of data[0] selects the kernel, so the
    # tensor's device plays no part in the lookup — confirm CUDA tensors are
    # never routed here, or that the kernel handles them.
    typename = type(data[0]).__name__.replace('Tensor', '')
    func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
    func(dim, *data)
    return data[0]
rusty1s's avatar
rusty1s committed
11
12
13
14
15
16


class _Scatter(Function):
    """Autograd wrapper around the in-place ``_scatter`` kernel call.

    Positional inputs: ``data[0]`` is the output tensor updated in place,
    ``data[1]`` is the index (never differentiated), ``data[2]`` is the
    tensor whose values are scattered. Any further arguments are passed
    through to the kernel and receive no gradient.
    """

    def __init__(self, name, dim):
        super(_Scatter, self).__init__()
        self.name = name  # Scatter operation name; selects the kernel.
        self.dim = dim  # Dimension used for both scatter and backward gather.

    def forward(self, *data):
        """Scatter into ``data[0]`` in place and return it.

        Raises:
            RuntimeError: if a gradient is requested for the index.
        """
        # Explicit check rather than ``assert`` so the guard survives
        # ``python -O``, which strips assert statements.
        if self.needs_input_grad[1]:
            raise RuntimeError('Can\'t differentiate the index')

        self.mark_dirty(data[0])  # Output is modified in place.
        self.len = len(data)  # Number of inputs; sizes backward's result.
        self.save_for_backward(data[1])  # Index is needed in backward.

        _scatter(self.name, self.dim, *data)
        return data[0]

    def backward(self, *data):
        """Return one gradient per forward input.

        ``data[0]`` is the gradient w.r.t. the returned output. That output's
        own gradient is passed through unchanged, the index gets ``None``,
        and ``data[2]`` gets the output gradient gathered along ``self.dim``
        at the saved index positions. Extra arguments get ``None``.
        """
        index, = self.saved_variables
        grad_output = grad_input = None

        if self.needs_input_grad[0]:
            grad_output = data[0]

        if self.needs_input_grad[2]:
            # TODO: max and min
            # NOTE(review): gathering the gradient back is presumably only
            # right for additive scatter ops; other ops need their own rule.
            grad_input = data[0].gather(self.dim, index.data)

        return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
rusty1s's avatar
rusty1s committed
40
41


rusty1s's avatar
rusty1s committed
42
43
44
def scatter(name, dim, *data):
    """Dispatch a scatter operation to the raw-tensor or autograd path.

    Args:
        name: Name of the scatter operation.
        dim: Dimension to scatter along.
        *data: Output, index, scattered tensor and any extra kernel arguments.
            If ``data[0]`` is a plain tensor the kernel is called directly;
            otherwise the arguments go through the ``_Scatter`` autograd
            function.

    Returns:
        The updated output: ``data[0]`` on the tensor path, or the result of
        ``_Scatter`` on the autograd path.
    """
    if torch.is_tensor(data[0]):
        # Return the output explicitly rather than whatever ``_scatter``
        # returns, so this branch yields the updated tensor like the
        # autograd branch does instead of ``None``.
        _scatter(name, dim, *data)
        return data[0]
    return _Scatter(name, dim)(*data)