Unverified Commit f7e29388 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #90 from rusty1s/segment

[WIP] segment_* operators
parents 4ceb2d1a d1dd9466
...@@ -11,10 +11,9 @@ def maybe_dim_size(index, dim_size=None): ...@@ -11,10 +11,9 @@ def maybe_dim_size(index, dim_size=None):
return index.max().item() + 1 if index.numel() > 0 else 0 return index.max().item() + 1 if index.numel() > 0 else 0
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0): def broadcast(src, index, dim):
dim = range(src.dim())[dim] # Get real dim value. dim = range(src.dim())[dim] # Get real dim value.
# Automatically expand index tensor to the right dimensions.
if index.dim() == 1: if index.dim() == 1:
index_size = list(repeat(1, src.dim())) index_size = list(repeat(1, src.dim()))
index_size[dim] = src.size(dim) index_size[dim] = src.size(dim)
...@@ -33,9 +32,17 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -33,9 +32,17 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
expand_size = [] expand_size = []
for s, i in zip(src.size(), index.size()): for s, i in zip(src.size(), index.size()):
expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)] expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
src = src.expand(expand_size) src = src.expand(expand_size)
index = index.expand_as(src) index = index.expand_as(src)
return src, index
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
src, index = broadcast(src, index, dim)
dim = range(src.dim())[dim] # Get real dim value.
# Generate output tensor if not given. # Generate output tensor if not given.
if out is None: if out is None:
out_size = list(src.size()) out_size = list(src.size())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment