Commit 5acf2355 authored by rusty1s's avatar rusty1s
Browse files

more doctests

parent bc139663
......@@ -48,7 +48,6 @@ def scatter_add_(output, index, input, dim=0):
0 0 4 3 3 0
2 4 4 0 0 0
[torch.FloatTensor of size 2x6]
"""
return output.scatter_add_(dim, index, input)
......@@ -79,6 +78,24 @@ def scatter_add(index, input, dim=0, size=None, fill_value=0):
dim (int, optional): The axis along which to index
size (int, optional): Output size at dimension :attr:`dim`
fill_value (int, optional): Initial filling of output tensor
.. testsetup::
import torch
from torch_scatter import scatter_add
.. testcode::
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_add(index, input, dim=1)
print(output)
.. testoutput::
0 0 4 3 3 0
2 4 4 0 0 0
[torch.FloatTensor of size 2x6]
"""
output = gen_output(index, input, dim, size, fill_value)
return scatter_add_(output, index, input, dim)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment