Commit c541e366 authored by rusty1s's avatar rusty1s
Browse files

cleaner

parent 8c48dee0
......@@ -12,10 +12,5 @@ def gen_filled_tensor(input, size, fill_value):
def gen_output(index, input, dim, max_index, fill_value):
max_index = index.max() + 1 if max_index is None else max_index
size = list(index.size())
if torch.is_tensor(input):
size[dim] = max_index
else:
size[dim] = max_index.data[0]
size[dim] = max_index if torch.is_tensor(input) else max_index.data[0]
return gen_filled_tensor(input, torch.Size(size), fill_value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment