Commit 5acf2355 authored by rusty1s's avatar rusty1s
Browse files

more doctests

parent bc139663
...@@ -48,7 +48,6 @@ def scatter_add_(output, index, input, dim=0): ...@@ -48,7 +48,6 @@ def scatter_add_(output, index, input, dim=0):
0 0 4 3 3 0 0 0 4 3 3 0
2 4 4 0 0 0 2 4 4 0 0 0
[torch.FloatTensor of size 2x6] [torch.FloatTensor of size 2x6]
""" """
return output.scatter_add_(dim, index, input) return output.scatter_add_(dim, index, input)
...@@ -79,6 +78,24 @@ def scatter_add(index, input, dim=0, size=None, fill_value=0): ...@@ -79,6 +78,24 @@ def scatter_add(index, input, dim=0, size=None, fill_value=0):
dim (int, optional): The axis along which to index dim (int, optional): The axis along which to index
size (int, optional): Output size at dimension :attr:`dim` size (int, optional): Output size at dimension :attr:`dim`
fill_value (int, optional): Initial filling of output tensor fill_value (int, optional): Initial filling of output tensor
.. testsetup::
import torch
from torch_scatter import scatter_add
.. testcode::
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_add(index, input, dim=1)
print(output)
.. testoutput::
0 0 4 3 3 0
2 4 4 0 0 0
[torch.FloatTensor of size 2x6]
""" """
output = gen_output(index, input, dim, size, fill_value) output = gen_output(index, input, dim, size, fill_value)
return scatter_add_(output, index, input, dim) return scatter_add_(output, index, input, dim)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment