scatter.py

import torch
from torch.autograd import Function

from .ffi import scatter as ffi_scatter


class Scatter(Function):
    def __init__(self, name, dim):
        super(Scatter, self).__init__()
        self.name = name
        self.dim = dim

    def save_for_backward_step(self, *data):  # pragma: no cover
        raise NotImplementedError

    def forward(self, *data):
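        # Expected argument layout: data = (output, index, input, *extras).
        # `output` is updated in-place and `index` is never differentiated.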
        assert not self.needs_input_grad[1], "Can't differentiate the index"

        self.mark_dirty(data[0])  # The kernel writes into `data[0]` in-place.
        self.len = len(data)  # Save the number of arguments for the backward step.

        # Run the named scatter kernel from the compiled extension.
        output = ffi_scatter(self.name, self.dim, *data)
        self.save_for_backward_step(*data)  # Subclass hook: stash what backward needs.

        return output

    def backward(self, *data):  # pragma: no cover
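        # `data` holds the incoming gradients of the forward outputs; since
        # forward returns a single tensor, data[0] is its gradient.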
        grad_output = grad_input = None

        if self.needs_input_grad[0]:
            grad_output = data[0]  # Pass the incoming gradient straight through.

        # Compute the gradient of `input` via the operation-specific backward step.
        if self.needs_input_grad[2]:
            grad_input = self.backward_step(data[0], *self.saved_variables)

        # Return gradients, filling in `None` for the non-differentiable arguments.
        return (grad_output, None, grad_input) + (None, ) * (self.len - 3)

    def backward_step(self, *data):  # pragma: no cover
        raise NotImplementedError
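

# Illustrative only: a concrete operation would subclass `Scatter`, fix the
# operation name passed to the FFI backend, and implement the two hooks.
# The name `ScatterAdd` and the gradient rule below are assumptions made for
# this sketch, not necessarily the package's real subclass.
class ScatterAdd(Scatter):
    def __init__(self, dim):
        super(ScatterAdd, self).__init__('add', dim)

    def save_for_backward_step(self, *data):
        self.save_for_backward(data[1])  # Only the index tensor is needed.

    def backward_step(self, grad, index):
        # For summation, the gradient of `input` is the output gradient
        # gathered back from the positions it was scattered to.
        return grad.gather(self.dim, index)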


def scatter(Clx, name, dim, *data):
    if torch.is_tensor(data[0]):
        # Plain tensors skip autograd and call the FFI kernel directly.
        return ffi_scatter(name, dim, *data)
    else:
        # Variables are routed through the autograd `Function` subclass.
        return Clx(dim)(*data)
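
# Usage sketch (assumes the illustrative `ScatterAdd` above and a compiled FFI
# backend). Plain tensors take the kernel path; wrapping them in `Variable`s
# would route through `ScatterAdd` for autograd support:
#
#   output = torch.zeros(2, 6)
#   index = torch.LongTensor([[0, 1, 2, 3, 4, 4], [5, 4, 3, 2, 1, 0]])
#   input = torch.rand(2, 6)
#   out = scatter(ScatterAdd, 'add', 1, output, index, input)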