scatter.py 2.19 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from itertools import chain

rusty1s's avatar
rusty1s committed
3
4
5
6
7
8
import torch
from torch.autograd import Function

from .._ext import ffi


rusty1s's avatar
rusty1s committed
9
def _scatter(name, dim, *data):
    """Validate the arguments and run the compiled ``scatter_<name>`` kernel.

    ``data`` is laid out as ``(output, index, input, *extra)``; every element
    (including any extras beyond the first three) is forwarded to the kernel.
    ``_Scatter.forward`` marks ``data[0]`` dirty, so the kernel is expected to
    write its result into ``output`` in place.

    Args:
        name: Operation suffix used to pick the kernel; the lookup is
            ``ffi.scatter_<name>_<Type>``.
        dim: Dimension along which ``index`` maps ``input`` into ``output``.
            Must satisfy ``0 <= dim < output.dim()``.
        *data: ``output``, ``index``, ``input`` and optional extra arguments.

    Returns:
        ``data[0]`` (the output), so plain-tensor callers get a result just
        like the autograd path, whose ``forward`` also returns ``data[0]``.

    Raises:
        AssertionError: if ``dim`` is out of range or the shapes of
            ``output``, ``index`` and ``input`` are inconsistent. These checks
            are ``assert`` statements and are skipped under ``python -O``.
        AttributeError: if ``ffi`` has no kernel for this name/tensor type.
    """
    output, index, src = data[:3]

    # Assert same dimensionality across all inputs.
    assert 0 <= dim < output.dim(), 'Index dimension is out of bounds'
    assert index.dim() == src.dim(), ('Index tensor must have same dimensions '
                                      'as input tensor')
    assert output.dim() == src.dim(), ('Input tensor must have same dimensions '
                                       'as output tensor')

    # Assert same tensor length across index and input.
    assert index.numel() == src.numel(), ('Index tensor must have same size '
                                          'as input tensor')

    # Assert same tensor sizes across input and output apart from `dim`.
    for d in chain(range(dim), range(dim + 1, output.dim())):
        assert output.size(d) == src.size(d), (
            'Input tensor must have same size as output tensor apart from the '
            'specified dimension')

    # Kernels are generated per tensor type: the type name with "Tensor"
    # removed (e.g. ``FloatTensor`` -> ``Float``) selects the variant.
    typename = type(output).__name__.replace('Tensor', '')
    func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
    func(dim, *data)
    # Previously the function fell off the end and returned None, which made
    # the plain-tensor branch of ``scatter`` return None.
    return output
rusty1s's avatar
rusty1s committed
32
33
34
35
36
37


class _Scatter(Function):
    """Autograd wrapper around the in-place ``_scatter`` kernel call.

    Written against the legacy, instance-based ``Function`` API. The
    operation name and scatter dimension are bound at construction time.
    The instance is then called with ``(output, index, input, *extra)``.
    """

    def __init__(self, name, dim):
        super(_Scatter, self).__init__()
        # Operation suffix and scatter dimension, used by forward/backward.
        self.name, self.dim = name, dim

    def forward(self, *data):
        """Run the scatter kernel on ``data`` and return ``data[0]``."""
        assert not self.needs_input_grad[1], 'Can\'t differentiate the index'

        output, index = data[0], data[1]
        # The kernel modifies ``output``; autograd must be told it changed.
        self.mark_dirty(output)
        # Arity is needed in backward to return one gradient per argument.
        self.len = len(data)
        # Only the index is needed later, to route gradients to ``input``.
        self.save_for_backward(index)

        _scatter(self.name, self.dim, *data)
        return output

    def backward(self, *data):
        """Return gradients for ``(output, index, input, *extra)``."""
        index, = self.saved_variables
        grad = data[0]
        wants_grad = self.needs_input_grad

        # The gradient for ``output`` is passed through unchanged.
        d_output = grad if wants_grad[0] else None

        d_input = None
        if wants_grad[2]:
            # Each ``input`` element picks up the gradient of the output slot
            # its index points to. TODO: max and min -- ``self.name`` is not
            # consulted here, so every op currently uses this same rule.
            d_input = grad.gather(self.dim, index.data)

        # The index and any trailing extra arguments never get gradients.
        padding = (None, ) * (self.len - 3)
        return (d_output, None, d_input) + padding
rusty1s's avatar
rusty1s committed
61
62


rusty1s's avatar
rusty1s committed
63
64
65
def scatter(name, dim, *data):
    """Dispatch a scatter operation to the tensor or the autograd path.

    Args:
        name: Operation suffix selecting the ``scatter_<name>_<Type>`` kernel.
        dim: Dimension to scatter along.
        *data: ``(output, index, input, *extra)``. The kernel writes into
            ``output`` (``data[0]``).

    Returns:
        For plain tensors, ``data[0]`` (the updated output). Otherwise, the
        result of applying ``_Scatter``, whose ``forward`` returns
        ``data[0]``.
    """
    if torch.is_tensor(data[0]):
        # Plain tensors need no autograd bookkeeping, so call the kernel
        # directly. Return the output explicitly instead of relying on the
        # helper's return value, which used to be None.
        _scatter(name, dim, *data)
        return data[0]
    # Non-tensor inputs (presumably autograd Variables; verify against
    # callers) go through ``_Scatter`` so gradients are tracked.
    return _Scatter(name, dim)(*data)