Commit f16e422a authored by rusty1s's avatar rusty1s
Browse files

py 2.7 fix

parent 96971292
...@@ -6,14 +6,14 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -6,14 +6,14 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
# Automatically expand index tensor to the right dimensions. # Automatically expand index tensor to the right dimensions.
if index.dim() == 1: if index.dim() == 1:
index_size = [*repeat(1, src.dim())] index_size = list(repeat(1, src.dim()))
index_size[dim] = src.size(dim) index_size[dim] = src.size(dim)
index = index.view(index_size).expand_as(src) index = index.view(index_size).expand_as(src)
# Generate output tensor if not given. # Generate output tensor if not given.
if out is None: if out is None:
dim_size = index.max() + 1 if dim_size is None else dim_size dim_size = index.max() + 1 if dim_size is None else dim_size
out_size = [*src.size()] out_size = list(src.size())
out_size[dim] = dim_size out_size[dim] = dim_size
out = src.new_full(out_size, fill_value) out = src.new_full(out_size, fill_value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment