# gen.py — last committed by rusty1s
from itertools import repeat


def maybe_dim_size(index, dim_size=None):
    """Return the size of the dimension addressed by ``index``.

    An explicitly supplied ``dim_size`` is returned unchanged. Otherwise the
    size is inferred as one past the largest entry of ``index``, or ``0``
    when ``index`` has no elements (``max()`` is undefined on an empty
    tensor, so that case is short-circuited first).

    Args:
        index: Tensor of integer indices (of any shape).
        dim_size: Optional explicit size that overrides inference.

    Returns:
        The explicit ``dim_size`` if given, else ``index.max().item() + 1``,
        else ``0`` for an empty ``index``.
    """
    if dim_size is None:
        if index.numel() == 0:
            return 0
        dim_size = index.max().item() + 1
    return dim_size


def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Normalize the arguments of an index-based operation along ``dim``.

    Presumably this feeds the package's scatter kernels; confirm with the
    callers.

    Args:
        src: Source tensor.
        index: Index tensor. A 1-D ``index`` is reshaped to lie along
            ``dim`` and then expanded to ``src``'s shape. It must therefore
            have exactly ``src.size(dim)`` elements, or ``view`` raises.
            Any other ``index`` is returned as-is.
        dim: Axis to operate on. Negative values count from the end.
        out: Optional preallocated output tensor, returned unchanged if given.
        dim_size: Size of a newly created ``out`` along ``dim``. When
            ``None``, it is inferred from ``index`` via ``maybe_dim_size``.
            It is ignored when ``out`` is given.
        fill_value: Value used to fill a newly created ``out``.

    Returns:
        Tuple ``(src, out, index, dim)``, where ``dim`` is non-negative.
    """
    ndim = src.dim()
    # Indexing a range maps negative dims to their positive equivalent and
    # raises IndexError for dims outside [-ndim, ndim).
    dim = range(ndim)[dim]

    if index.dim() == 1:
        # All-ones shape except along ``dim``, so the 1-D index lines up with
        # that axis and can be broadcast over every other axis of ``src``.
        view_shape = [1] * ndim
        view_shape[dim] = src.size(dim)
        index = index.view(view_shape).expand_as(src)

    if out is None:
        # Same shape as ``src`` except along ``dim``.
        out_shape = list(src.size())
        out_shape[dim] = maybe_dim_size(index, dim_size)
        out = src.new_full(out_shape, fill_value)

    return src, out, index, dim