Commit 5e69a349 authored by rusty1s's avatar rusty1s
Browse files

only execute kernels if size(dim) > 0

parent a8e3e285
......@@ -89,4 +89,6 @@ def scatter_div(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
[0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]])
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterDiv.apply(out, src, index, dim)
......@@ -96,4 +96,6 @@ def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
[ 1, 4, 3, -1, -1, -1]])
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterMax.apply(out, src, index, dim)
......@@ -96,4 +96,6 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
[ 1, 4, 3, -1, -1, -1]])
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterMin.apply(out, src, index, dim)
......@@ -88,4 +88,6 @@ def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
[6, 4, 8, 1, 1, 1]])
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterMul.apply(out, src, index, dim)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment