Commit 21478411 authored by rusty1s's avatar rusty1s
Browse files

dim size fix for index.numel() == 0

parent 50e4214b
from itertools import repeat
def maybe_dim_size(index, dim_size=None):
if dim_size is not None:
return dim_size
return index.max().item() + 1 if index.numel() > 0 else 0
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
dim = range(src.dim())[dim] # Get real dim value.
......@@ -12,8 +18,8 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
# Generate output tensor if not given.
if out is None:
dim_size = index.max().item() + 1 if dim_size is None else dim_size
out_size = list(src.size())
dim_size = maybe_dim_size(index, dim_size)
out_size[dim] = dim_size
out = src.new_full(out_size, fill_value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment