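"""Autograd glue between the Python-facing ``scatter`` interface and the
C/CUDA scatter kernels exposed via the FFI extension."""
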
from itertools import chain

import torch
from torch.autograd import Function

from .._ext import ffi


def has_arg_output(name):
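    """Return whether the scatter operation `name` also produces an argument
    tensor (the positions of the `argmax`/`argmin`)."""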
    return name in ['max', 'min']


def _scatter(name, dim, *data):
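    """Dispatch directly to the `scatter_<name>` FFI kernel. As inferred from
    the assertions and the backward pass, `data` is laid out as
    `(output, index, input)`, followed by an additional `arg_output` tensor
    for the `max`/`min` operations. `output` is modified in-place."""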
    a, b, c = data[:3]  # output, index, input

    # Assert index dimension is valid.
    assert 0 <= dim < a.dim(), 'Index dimension is out of bounds'

    # Assert same dimensionality across all inputs.
    assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
                                'input tensor')
    assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
                                'output tensor')

    # Assert same number of elements across index and input.
    assert b.numel() == c.numel(), ('Index tensor must have same number of '
                                    'elements as input tensor')

    # Assert same tensor sizes across input and output apart from `dim`.
    for d in chain(range(dim), range(dim + 1, a.dim())):
        assert a.size(d) == c.size(d), (
            'Input tensor must have same size as output tensor apart from the '
            'specified dimension')

    typename = type(data[0]).__name__.replace('Tensor', '')
    cuda = 'cuda_' if data[0].is_cuda else ''
    func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
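    # E.g. `name='add'` on a CUDA FloatTensor resolves to
    # `ffi.scatter_add_cuda_Float`.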
    func(dim, *data)
    return (data[0], data[3]) if has_arg_output(name) else data[0]


def index_backward(dim, index, grad, arg_grad):
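    """Route `grad` back through the `argmax`/`argmin` positions saved by a
    previous `scatter_max`/`scatter_min` call, yielding the gradient with
    respect to `input`."""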
    typename = type(grad).__name__.replace('Tensor', '')
    cuda = 'cuda_' if grad.is_cuda else ''
    func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
    output = grad.new(index.size()).fill_(0)  # Zeroed gradient buffer.
    func(dim, output, index, grad, arg_grad)
    return output


class _Scatter(Function):
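    """Old-style `torch.autograd.Function` (stateful instances, as used
    before PyTorch 0.4) wrapping the in-place FFI scatter kernels."""
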
    def __init__(self, name, dim):
        super(_Scatter, self).__init__()
        self.name = name
        self.dim = dim

    def forward(self, *data):
        assert not self.needs_input_grad[1], 'Can\'t differentiate the index'

        self.mark_dirty(data[0])  # Mark output as dirty.
        self.len = len(data)  # Save number of arguments for backward step.

        _scatter(self.name, self.dim, *data)

        # `scatter_max` and `scatter_min` additionally return the `argmax`
        # and `argmin`, respectively. This `arg_output` also needs to be
        # saved for the backward pass.
        if has_arg_output(self.name):
            self.save_for_backward(data[1], data[3])
            return data[0], data[3]
        else:
            self.save_for_backward(data[1])
            return data[0]

    def backward(self, *data):
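        # `data` holds one incoming gradient per forward output: `data[0]`
        # for the (in-place) scatter output; for `max`/`min`, `data[1]` is
        # the grad of the returned `argmax`/`argmin` and is ignored.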
        grad_output = grad_input = None

        if self.needs_input_grad[0]:
            grad_output = data[0]

        # The gradient of `input` is computed differently if `scatter_max`
        # or `scatter_min` was used: gradients are routed through the saved
        # `argmax`/`argmin` positions rather than gathered via `index`.
        if self.needs_input_grad[2] and not has_arg_output(self.name):
            index, = self.saved_variables
            grad_input = data[0].gather(self.dim, index.data)

        if self.needs_input_grad[2] and has_arg_output(self.name):
            index, arg_grad = self.saved_variables
            data = (index.data, data[0], arg_grad.data)
            grad_input = index_backward(self.dim, *data)

        # Return the grads, filling in `None` for the non-differentiable
        # arguments passed in the forward pass.
        return (grad_output, None, grad_input) + (None, ) * (self.len - 3)


def scatter(name, dim, *data):
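    """Public entry point: plain tensors are dispatched straight to the FFI
    kernel, while `Variable` inputs go through the autograd `_Scatter`
    function.

    A minimal usage sketch (shapes purely illustrative; `data` is assumed to
    be laid out as `(output, index, input)`):

        output = torch.zeros(2, 6)
        index = torch.LongTensor([[0, 1, 1], [2, 2, 0]])
        input = torch.ones(2, 3)
        scatter('add', 1, output, index, input)
    """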
    if torch.is_tensor(data[0]):
        return _scatter(name, dim, *data)
    else:
        return _Scatter(name, dim)(*data)