gen.py 599 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from itertools import repeat


def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Prepare the output tensor and the broadcast index for a scatter op.

    Args:
        src: Source tensor whose values are to be scattered.
        index: Target indices. A 1-D ``index`` is reshaped so it runs along
            ``dim`` and is then expanded to ``src``'s shape; any other rank is
            passed through unchanged.
        dim: Axis along which to index (negative values count from the end).
        out: Optional pre-allocated output tensor. If given, it is returned
            as-is and ``dim_size`` / ``fill_value`` are ignored.
        dim_size: Size of the newly created ``out`` along ``dim``. Defaults to
            ``index.max() + 1``, or ``0`` when ``index`` is empty.
        fill_value: Value used to initialise a newly created ``out``.

    Returns:
        Tuple ``(out, index)``, where ``index`` has been expanded to
        ``src``'s shape if it was 1-D.
    """
    # Automatically expand index tensor to the right dimensions.
    if index.dim() == 1:
        index_size = [1] * src.dim()
        index_size[dim] = src.size(dim)
        index = index.view(index_size).expand_as(src)

    # Generate output tensor if not given.
    if out is None:
        if dim_size is None:
            # index.max() raises on an empty tensor and otherwise returns a
            # 0-dim tensor; convert to a plain int so the shape list is ints.
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        out_size = list(src.size())
        out_size[dim] = dim_size
        out = src.new_full(out_size, fill_value)

    return out, index