gen.py 1002 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from itertools import repeat

rusty1s's avatar
rusty1s committed
3
4
import torch

rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
7
8
9
10
11
def maybe_dim_size(index, dim_size=None):
    """Return the size to use for the indexed dimension.

    Args:
        index: Tensor of indices; only inspected when ``dim_size`` is None.
        dim_size: Explicit size. When not None it is returned unchanged.

    Returns:
        ``dim_size`` if it was given; otherwise ``index.max() + 1`` as a
        Python number, or ``0`` when ``index`` has no elements.
    """
    if dim_size is None:
        # An empty tensor has no maximum, so report a size of zero.
        if index.numel() == 0:
            return 0
        largest = index.max().item()
        return largest + 1
    return dim_size


rusty1s's avatar
rusty1s committed
12
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Prepare the tensors for an index-based operation along ``dim``.

    Args:
        src: Source tensor.
        index: Index tensor. A one-dimensional index is reshaped and
            expanded to the shape of ``src``; any other index is passed
            through unchanged.
        dim: Dimension of ``src`` that ``index`` refers to. Negative values
            count from the last dimension.
        out: Optional output tensor. When None, a new one is allocated.
        dim_size: Size of the new output tensor along ``dim``. Only used
            when ``out`` is None; if it is also None, the size is derived
            from ``index`` via ``maybe_dim_size``.
        fill_value: Value the newly allocated output tensor is filled with.

    Returns:
        The tuple ``(src, out, index, dim)`` where ``dim`` is non-negative.

    Raises:
        IndexError: If ``dim`` is out of range for ``src``.
    """
    # Indexing a range resolves negative dims and rejects out-of-range ones.
    dim = range(src.dim())[dim]

    # Automatically expand index tensor to the right dimensions.
    if index.dim() == 1:
        # Size 1 everywhere except along ``dim``, so it broadcasts over src.
        index_size = [1] * src.dim()
        index_size[dim] = src.size(dim)
        if index.numel() > 0:
            index = index.view(index_size).expand_as(src)
        else:  # PyTorch has a bug when view is used on zero-element tensors.
            index = src.new_empty(index_size, dtype=torch.long)

    # Generate output tensor if not given.
    if out is None:
        out_size = list(src.size())
        out_size[dim] = maybe_dim_size(index, dim_size)
        out = src.new_full(out_size, fill_value)

    return src, out, index, dim