Commit 5e69a349 authored by rusty1s's avatar rusty1s
Browse files

only execute kernels if size(dim) > 0

parent a8e3e285
...@@ -89,4 +89,6 @@ def scatter_div(src, index, dim=-1, out=None, dim_size=None, fill_value=1): ...@@ -89,4 +89,6 @@ def scatter_div(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
[0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]]) [0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]])
""" """
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterDiv.apply(out, src, index, dim) return ScatterDiv.apply(out, src, index, dim)
...@@ -96,4 +96,6 @@ def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -96,4 +96,6 @@ def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
[ 1, 4, 3, -1, -1, -1]]) [ 1, 4, 3, -1, -1, -1]])
""" """
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterMax.apply(out, src, index, dim) return ScatterMax.apply(out, src, index, dim)
...@@ -96,4 +96,6 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -96,4 +96,6 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
[ 1, 4, 3, -1, -1, -1]]) [ 1, 4, 3, -1, -1, -1]])
""" """
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterMin.apply(out, src, index, dim) return ScatterMin.apply(out, src, index, dim)
...@@ -88,4 +88,6 @@ def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1): ...@@ -88,4 +88,6 @@ def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
[6, 4, 8, 1, 1, 1]]) [6, 4, 8, 1, 1, 1]])
""" """
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterMul.apply(out, src, index, dim) return ScatterMul.apply(out, src, index, dim)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment