Unverified Commit f7e29388 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #90 from rusty1s/segment

[WIP] segment_* operators
parents 4ceb2d1a d1dd9466
......@@ -11,10 +11,9 @@ def maybe_dim_size(index, dim_size=None):
return index.max().item() + 1 if index.numel() > 0 else 0
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
def broadcast(src, index, dim):
dim = range(src.dim())[dim] # Get real dim value.
# Automatically expand index tensor to the right dimensions.
if index.dim() == 1:
index_size = list(repeat(1, src.dim()))
index_size[dim] = src.size(dim)
......@@ -33,9 +32,17 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
expand_size = []
for s, i in zip(src.size(), index.size()):
expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
src = src.expand(expand_size)
index = index.expand_as(src)
return src, index
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
src, index = broadcast(src, index, dim)
dim = range(src.dim())[dim] # Get real dim value.
# Generate output tensor if not given.
if out is None:
out_size = list(src.size())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment