Commit 3922eca6 authored by rusty1s's avatar rusty1s
Browse files

Update doctest

parent 4a119480
...@@ -66,8 +66,8 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -66,8 +66,8 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
.. testoutput:: .. testoutput::
tensor([[0, 0, 4, 3, 3, 0], tensor([[0., 0., 4., 3., 3., 0.],
[2, 4, 4, 0, 0, 0]]) [2., 4., 4., 0., 0., 0.]])
""" """
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
return out.scatter_add_(dim, index, src) return out.scatter_add_(dim, index, src)
...@@ -93,8 +93,8 @@ def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None): ...@@ -93,8 +93,8 @@ def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
.. testoutput:: .. testoutput::
tensor([[0, 0, 4, 3, 2, 0], tensor([[0., 0., 4., 3., 2., 0.],
[2, 4, 3, 0, 0, 0]]) [2., 4., 3., 0., 0., 0.]])
tensor([[-1, -1, 3, 4, 0, 1], tensor([[-1, -1, 3, 4, 0, 1],
[ 1, 4, 3, -1, -1, -1]]) [ 1, 4, 3, -1, -1, -1]])
""" """
......
...@@ -95,8 +95,8 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=None): ...@@ -95,8 +95,8 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
.. testoutput:: .. testoutput::
tensor([[ 0, 0, -4, -3, -2, 0], tensor([[ 0., 0., -4., -3., -2., 0.],
[-2, -4, -3, 0, 0, 0]]) [-2., -4., -3., 0., 0., 0.]])
tensor([[-1, -1, 3, 4, 0, 1], tensor([[-1, -1, 3, 4, 0, 1],
[ 1, 4, 3, -1, -1, -1]]) [ 1, 4, 3, -1, -1, -1]])
""" """
......
...@@ -84,8 +84,8 @@ def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1): ...@@ -84,8 +84,8 @@ def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
.. testoutput:: .. testoutput::
tensor([[1, 1, 4, 3, 6, 0], tensor([[1., 1., 4., 3., 6., 0.],
[6, 4, 8, 1, 1, 1]]) [6., 4., 8., 1., 1., 1.]])
""" """
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value) src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover if src.size(dim) == 0: # pragma: no cover
......
...@@ -58,7 +58,7 @@ def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -58,7 +58,7 @@ def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
.. testoutput:: .. testoutput::
tensor([[ 0, 0, -4, -3, -3, 0], tensor([[ 0., 0., -4., -3., -3., 0.],
[-2, -4, -4, 0, 0, 0]]) [-2., -4., -4., 0., 0., 0.]])
""" """
return scatter_add(src.neg(), index, dim, out, dim_size, fill_value) return scatter_add(src.neg(), index, dim, out, dim_size, fill_value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment