"src/diffusers/pipelines/sana/pipeline_sana_controlnet.py" did not exist on "a5720e9e3124753c85b2260dec5f39d75ce18245"
gen.py 665 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from itertools import repeat


def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Resolve ``dim``, expand a 1-D ``index`` and allocate ``out`` if missing.

    Args:
        src: Source tensor. Only ``dim()``, ``size()`` and ``new_full()`` are
            used here (presumably a ``torch.Tensor`` -- confirm with callers).
        index: Index tensor. A 1-D ``index`` is viewed along ``dim`` and
            expanded to the shape of ``src``; any other rank is passed
            through unchanged.
        dim: Axis of ``src`` the indices refer to. Negative values count
            from the last axis.
        out: Optional pre-allocated output. When given it is returned as-is
            and ``dim_size`` / ``fill_value`` are ignored.
        dim_size: Size of the new output along ``dim``. When omitted it is
            ``index.max() + 1``, or ``0`` if ``index`` has no elements.
        fill_value: Value the newly allocated output is filled with.

    Returns:
        The tuple ``(src, out, index, dim)`` where ``index`` is the
        (possibly expanded) index tensor and ``dim`` is non-negative.

    Raises:
        IndexError: If ``dim`` is out of range for ``src``.
    """
    # Indexing a range maps negative axes to their non-negative equivalent
    # and rejects out-of-range values in a single step.
    dim = range(src.dim())[dim]

    # Automatically expand index tensor to the right dimensions.
    if index.dim() == 1:
        index_size = [1] * src.dim()
        index_size[dim] = src.size(dim)
        index = index.view(index_size).expand_as(src)

    # Generate output tensor if not given.
    if out is None:
        if dim_size is None:
            # max() is undefined for an empty tensor, so an empty index maps
            # to an empty output axis. .item() yields a plain Python int so
            # that the size list contains no tensors.
            dim_size = index.max().item() + 1 if index.numel() > 0 else 0
        out_size = list(src.size())
        out_size[dim] = dim_size
        out = src.new_full(out_size, fill_value)

    return src, out, index, dim