ffi.py 1.64 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from itertools import chain

from .._ext import ffi


def scatter(name, dim, *data):
    """Validate the arguments and dispatch to the matching ``ffi`` kernel.

    The kernel is looked up on ``ffi`` under the name
    ``scatter_{name}_{cuda_}{typename}``, where ``typename`` is the class
    name of ``data[0]`` with ``'Tensor'`` removed and the ``cuda_`` infix is
    only present when ``data[0].is_cuda`` is true. It is then called as
    ``func(dim, *data)``; its return value is ignored (presumably it writes
    its results into the passed tensors in place -- confirm against the
    extension sources).

    Args:
        name: Operation name inserted into the kernel's function name.
        dim: Dimension handed to the kernel; must satisfy
            ``0 <= dim < index.dim()``.
        *data: ``output, index, input`` followed by any additional
            arguments. Only the first three are validated; the rest are
            forwarded to the kernel unchecked.

    Returns:
        ``data[0]`` when at most three data arguments were given, otherwise
        the tuple ``(data[0],) + data[3:]``.

    Raises:
        AssertionError: If one of the dimension or size checks fails.
    """
    # data = output, index, input, additional data
    output, index, source = data[:3]

    # The checks below raise explicitly instead of using `assert` statements:
    # asserts are removed under `python -O`, which would let unchecked
    # arguments reach the native kernel. AssertionError is kept so callers
    # observe the same exception type and message as before.

    # Index dimension must be valid.
    if not 0 <= dim < index.dim():
        raise AssertionError('Index dimension is out of bounds')

    # Same dimensionality across all inputs.
    if index.dim() != source.dim():
        raise AssertionError('Index tensor must have same dimensions as '
                             'input tensor')
    if output.dim() != source.dim():
        raise AssertionError('Input tensor must have same dimensions as '
                             'output tensor')

    # Same number of elements across index and input.
    if index.numel() != source.numel():
        raise AssertionError('Index tensor must have same size as '
                             'input tensor')

    # Same tensor sizes across input and output apart from `dim`.
    for d in chain(range(dim), range(dim + 1, output.dim())):
        if output.size(d) != source.size(d):
            raise AssertionError(
                'Input tensor must have same size as output tensor apart '
                'from the specified dimension')

    # Resolve the type- and device-specific kernel from the output tensor.
    typename = type(data[0]).__name__.replace('Tensor', '')
    cuda = 'cuda_' if data[0].is_cuda else ''
    func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
    func(dim, *data)

    if len(data) <= 3:
        return data[0]

    # Index and input are dropped from the result; extra arguments are kept.
    return (data[0], ) + tuple(data[3:])


def index_backward(dim, index, grad, arg):  # pragma: no cover
    """Call the ``ffi`` ``index_backward`` kernel and return its output buffer.

    The kernel is looked up on ``ffi`` as
    ``index_backward_{cuda_}{typename}``, where ``typename`` is the class
    name of ``grad`` with ``'Tensor'`` removed and the ``cuda_`` infix is
    only present when ``grad.is_cuda`` is true.

    Args:
        dim: Dimension forwarded to the kernel.
        index: Tensor whose ``size()`` determines the size of the result.
        grad: Tensor that selects the kernel variant and from which the
            result buffer is created via ``grad.new``.
        arg: Extra argument forwarded to the kernel unchanged.

    Returns:
        A new tensor created from ``grad`` with the size of ``index``,
        zero-filled before being handed to the kernel.
    """
    if grad.is_cuda:
        device_prefix = 'cuda_'
    else:
        device_prefix = ''
    type_suffix = type(grad).__name__.replace('Tensor', '')

    kernel_name = 'index_backward_{}{}'.format(device_prefix, type_suffix)
    kernel = getattr(ffi, kernel_name)

    # Zero-initialised result buffer, sized like `index`.
    result = grad.new(index.size())
    result = result.fill_(0)

    kernel(dim, result, index, grad, arg)
    return result