from itertools import chain

import torch
from torch.autograd import Function

from .._ext import ffi


def _has_output_index(name):
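    """Whether operation `name` additionally computes an output index tensor
    (the argument positions of the maxima/minima)."""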
    return name in ['max', 'min']


def _scatter(name, dim, *data):
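    """Run the in-place FFI scatter kernel `name` along dimension `dim`.

    Judging from the assertions below, `data` holds `(output, index, input)`,
    followed by an `output_index` tensor for `max`/`min`.
    """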
    a, b, c = data[:3]

    # Assert index dimension is valid.
    assert 0 <= dim < a.dim(), 'Index dimension is out of bounds'

    # Assert same dimensionality across all inputs.
    assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
                                'input tensor')
    assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
                                'output tensor')

    # Assert same tensor length across index and input.
    assert b.numel() == c.numel(), ('Index tensor must have same size as '
                                    'input tensor')

    # Assert same tensor sizes across input and output apart from `dim`.
    for d in chain(range(dim), range(dim + 1, a.dim())):
        assert a.size(d) == c.size(d), (
            'Input tensor must have same size as output tensor apart from the '
            'specified dimension')

    # Dispatch to the compiled FFI kernel matching the tensor type, e.g.
    # `scatter_add_Float` for `name='add'` and `torch.FloatTensor` data.
    typename = type(data[0]).__name__.replace('Tensor', '')
    func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
    func(dim, *data)
    return (data[0], data[3]) if _has_output_index(name) else data[0]


def _index_backward(dim, index, grad, grad_index):
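    """Backward helper for `max`/`min`: fill a zero tensor shaped like
    `index` with gradients from `grad`, routed via the saved `grad_index`
    (the arg positions recorded in the forward pass)."""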
    typename = type(grad).__name__.replace('Tensor', '')
    func = getattr(ffi, 'index_backward_{}'.format(typename))
    output = grad.new(index.size()).fill_(0)
    func(dim, output, index, grad, grad_index)
    return output


class _Scatter(Function):
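    """Autograd function that applies the in-place scatter kernel on the
    forward pass and routes gradients back via `gather` (or via
    `_index_backward` for `max`/`min`) on the backward pass."""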
    def __init__(self, name, dim):
        super(_Scatter, self).__init__()
        self.name = name
        self.dim = dim

    def forward(self, *data):
        assert not self.needs_input_grad[1], 'Can\'t differentiate the index'

        self.mark_dirty(data[0])  # Mark output as dirty.
        self.len = len(data)  # Save number of arguments for backward step.

        _scatter(self.name, self.dim, *data)

        # `scatter_max` and `scatter_min` additionally return the `argmax`
        # and `argmin`, respectively. We also need to save the `output_index`
        # for the backward pass.
        if _has_output_index(self.name):
            self.save_for_backward(data[1], data[3])
            return data[0], data[3]
        else:
            self.save_for_backward(data[1])
            return data[0]

    def backward(self, *data):
        grad_output = grad_input = None

        if self.needs_input_grad[0]:
            grad_output = data[0]

        # The gradient of `input` is computed differently if `scatter_max`
        # or `scatter_min` was used.
        if self.needs_input_grad[2] and not _has_output_index(self.name):
            index, = self.saved_variables
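            # Each input position receives the gradient of the output entry
            # it was scattered into.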
            grad_input = data[0].gather(self.dim, index.data)

        if self.needs_input_grad[2] and _has_output_index(self.name):
            index, grad_index = self.saved_variables
            data = (index.data, data[0], grad_index.data)
            grad_input = _index_backward(self.dim, *data)

        # Return gradients, filling in empty (`None`) grads for the
        # non-differentiable arguments passed to the forward pass.
        return (grad_output, None, grad_input) + (None, ) * (self.len - 3)


def scatter(name, dim, *data):
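    """Apply scatter operation `name`: plain tensors run through the FFI
    kernel directly, everything else (e.g. autograd `Variable`s) through the
    differentiable `_Scatter` function."""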
    if torch.is_tensor(data[0]):
        return _scatter(name, dim, *data)
    else:
        return _Scatter(name, dim)(*data)
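

# Minimal usage sketch, assuming an `add` kernel was compiled. The calling
# convention is inferred from the argument checks in `_scatter` (output
# first, then index and input) rather than taken from public documentation:
#
#   output = torch.zeros(2, 6)
#   index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
#   input = torch.ones(2, 5)
#   output = scatter('add', 1, output, index, input)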