gen.py 1.47 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from __future__ import division

rusty1s's avatar
rusty1s committed
3
4
from itertools import repeat

rusty1s's avatar
rusty1s committed
5
6
import torch

rusty1s's avatar
rusty1s committed
7

rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
def maybe_dim_size(index, dim_size=None):
    """Return ``dim_size`` if it is given, otherwise derive it from ``index``.

    Args:
        index: Tensor of indices.
        dim_size: Explicit size. Any value other than ``None`` is returned
            unchanged.

    Returns:
        ``dim_size`` when it is not ``None``; otherwise the largest value in
        ``index`` plus one, or ``0`` when ``index`` has no elements.
    """
    if dim_size is None:
        # max() is undefined for a tensor without elements, so handle the
        # empty case before calling it.
        if index.numel() == 0:
            return 0
        largest = index.max().item()
        return largest + 1
    return dim_size


rusty1s's avatar
rusty1s committed
14
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Broadcast ``src`` and ``index`` to a common shape and prepare ``out``.

    Args:
        src: Source tensor.
        index: Index tensor. A one-dimensional index is reshaped so that it
            lies along ``dim`` and is expanded to the shape of ``src``;
            otherwise it must have as many dimensions as ``src``.
        dim: Dimension the index refers to. Negative values count from the
            last dimension of ``src``.
        out: Optional output tensor. It is returned unchanged when given.
        dim_size: Size of a newly created ``out`` along ``dim``. If ``None``,
            it is derived from ``index`` by ``maybe_dim_size``. Ignored when
            ``out`` is given.
        fill_value: Value a newly created ``out`` is filled with.

    Returns:
        The tuple ``(src, out, index, dim)`` where ``src`` and ``index`` are
        expanded to a common shape and ``dim`` is a non-negative integer.

    Raises:
        ValueError: If ``src`` and ``index`` differ in their number of
            dimensions, or if a one-dimensional ``index`` is empty while
            ``src`` is not empty along ``dim``.
    """
    dim = range(src.dim())[dim]  # Get real dim value.

    # Automatically expand index tensor to the right dimensions.
    if index.dim() == 1:
        index_size = list(repeat(1, src.dim()))
        index_size[dim] = src.size(dim)
        if index.numel() > 0:
            index = index.view(index_size).expand_as(src)
        elif src.size(dim) == 0:
            # PyTorch has a bug when view is used on zero-element tensors,
            # so the (empty) index is allocated directly instead.
            index = src.new_empty(index_size, dtype=torch.long)
        else:
            # Allocating a non-empty tensor here would hand uninitialized
            # memory to the caller as index values.
            raise ValueError(
                ('Empty index tensor cannot be expanded along dimension {} '
                 'of src with size {}').format(dim, src.size(dim)))

    # Broadcasting capabilities: Expand dimensions to match.
    if src.dim() != index.dim():
        raise ValueError(
            ('Number of dimensions of src and index tensor do not match, '
             'got {} and {}').format(src.dim(), index.dim()))

    # -1 keeps a dimension as it is; otherwise the larger of both sizes is
    # the broadcast target (expand raises if the sizes are incompatible).
    expand_size = [
        -1 if s == i else max(i, s)
        for s, i in zip(src.size(), index.size())
    ]
    src = src.expand(expand_size)
    index = index.expand_as(src)

    # Generate output tensor if not given.
    if out is None:
        out_size = list(src.size())
        dim_size = maybe_dim_size(index, dim_size)
        out_size[dim] = dim_size
        out = src.new_full(out_size, fill_value)

    return src, out, index, dim