from itertools import repeat


def maybe_dim_size(index, dim_size=None):
    """Return the output size to use along the scattered dimension.

    Args:
        index: Tensor of target indices.
        dim_size: Explicit size; returned unchanged when it is not ``None``.

    Returns:
        ``dim_size`` if given, otherwise ``index.max() + 1`` as a Python
        number, or ``0`` when ``index`` has no elements.
    """
    if dim_size is not None:
        return dim_size
    # ``max()`` is undefined on an empty tensor, so guard on ``numel()``.
    if index.numel() == 0:
        return 0
    return index.max().item() + 1


def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Prepare the ``(src, out, index, dim)`` arguments for a scatter op.

    Args:
        src: Source tensor.
        index: Index tensor. A 1-D index is reshaped so its entries lie
            along ``dim`` and is then expanded to the shape of ``src``.
        dim: Dimension to index along; negative values count from the end.
        out: Optional pre-allocated output tensor. When given it is returned
            as-is and ``dim_size`` / ``fill_value`` are not used.
        dim_size: Size of the output along ``dim`` when ``out`` is ``None``.
            Defaults to ``index.max() + 1`` (``0`` for an empty index).
        fill_value: Value the freshly allocated output is filled with.

    Returns:
        Tuple ``(src, out, index, dim)`` with the possibly expanded
        ``index``, the possibly newly allocated ``out`` and the
        non-negative ``dim``.

    Raises:
        IndexError: If ``dim`` is out of range for ``src``.
    """
    # Indexing a range normalizes negative dims and rejects invalid ones.
    dim = range(src.dim())[dim]

    # Automatically expand a 1-D index tensor to the shape of ``src``.
    if index.dim() == 1:
        index_size = list(repeat(1, src.dim()))
        index_size[dim] = src.size(dim)
        index = index.view(index_size).expand_as(src)

    # Generate the output tensor if it was not given.
    if out is None:
        out_size = list(src.size())
        dim_size = maybe_dim_size(index, dim_size)
        out_size[dim] = dim_size
        # ``new_full`` keeps the dtype and device of ``src``.
        out = src.new_full(out_size, fill_value)

    return src, out, index, dim