Commit 1cd8aa6e authored by rusty1s's avatar rusty1s
Browse files

added doc

parent d99469a5
...@@ -2,7 +2,7 @@ from .utils import gen_output ...@@ -2,7 +2,7 @@ from .utils import gen_output
def scatter_add_(output, index, input, dim=0): def scatter_add_(output, index, input, dim=0):
"""Sums all values from the tensor :attr:`input` into :attr:`output` at the """Sums all values from the :attr:`input` tensor into :attr:`output` at the
indices specified in the :attr:`index` tensor along an given axis indices specified in the :attr:`index` tensor along an given axis
:attr:`dim`. For each value in :attr:`input`, its output index is specified :attr:`dim`. For each value in :attr:`input`, its output index is specified
by its index in :attr:`input` for dimensions outside of :attr:`dim` and by by its index in :attr:`input` for dimensions outside of :attr:`dim` and by
...@@ -11,7 +11,7 @@ def scatter_add_(output, index, input, dim=0): ...@@ -11,7 +11,7 @@ def scatter_add_(output, index, input, dim=0):
If :attr:`input` and :attr:`index` are n-dimensional tensors with size If :attr:`input` and :attr:`index` are n-dimensional tensors with size
:math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
:attr:`dim` = i, then :attr:`output` must be an n-dimensional tensor with :attr:`dim` = `i`, then :attr:`output` must be an n-dimensional tensor with
size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
values of :attr:`index` must be between `0` and `output.size(dim) - 1`. values of :attr:`index` must be between `0` and `output.size(dim) - 1`.
...@@ -33,10 +33,10 @@ def scatter_add_(output, index, input, dim=0): ...@@ -33,10 +33,10 @@ def scatter_add_(output, index, input, dim=0):
.. testsetup:: .. testsetup::
import torch import torch
from torch_scatter import scatter_add_
.. testcode:: .. testcode::
from torch_scatter import scatter_add_
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.zeros(2, 6) output = torch.zeros(2, 6)
...@@ -53,12 +53,14 @@ def scatter_add_(output, index, input, dim=0): ...@@ -53,12 +53,14 @@ def scatter_add_(output, index, input, dim=0):
def scatter_add(index, input, dim=0, size=None, fill_value=0): def scatter_add(index, input, dim=0, size=None, fill_value=0):
"""Sums all values from the tensor :attr:`input` at the indices specified """Sums all values from the :attr:`input` tensor at the indices specified
in the :attr:`index` tensor along an given axis :attr:`dim`. The output in the :attr:`index` tensor along an given axis :attr:`dim` (`cf.`
size at dimension :attr:`dim` is given by :attr:`size` and must be at least :meth:`~torch_scatter.scatter_add_`).
size `index.max(dim) - 1`. If :attr:`size` is not given, a minimal sized
output tensor is returned. The output tensor is prefilled with the The output size at dimension :attr:`dim` is given by :attr:`size` and must
specified value from :attr:`fill_value`. be at least size `index.max(dim) - 1`. If :attr:`size` is not given, a
minimal sized output tensor is returned. The output tensor is prefilled
with the specified value from :attr:`fill_value`.
For one-dimensional tensors, the operation computes For one-dimensional tensors, the operation computes
...@@ -67,9 +69,6 @@ def scatter_add(index, input, dim=0, size=None, fill_value=0): ...@@ -67,9 +69,6 @@ def scatter_add(index, input, dim=0, size=None, fill_value=0):
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`. where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.
A more detailed explanation is described in
:meth:`~torch_scatter.scatter_add_`.
Args: Args:
index (LongTensor): The indices of elements to scatter index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor input (Tensor): The source tensor
...@@ -82,10 +81,10 @@ def scatter_add(index, input, dim=0, size=None, fill_value=0): ...@@ -82,10 +81,10 @@ def scatter_add(index, input, dim=0, size=None, fill_value=0):
.. testsetup:: .. testsetup::
import torch import torch
from torch_scatter import scatter_add
.. testcode:: .. testcode::
from torch_scatter import scatter_add
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_add(index, input, dim=1) output = scatter_add(index, input, dim=1)
......
...@@ -2,18 +2,53 @@ from .utils import gen_output ...@@ -2,18 +2,53 @@ from .utils import gen_output
def scatter_sub_(output, index, input, dim=0): def scatter_sub_(output, index, input, dim=0):
"""If multiple indices reference the same location, their **negated """Subtracts all values from the :attr:`input` tensor into :attr:`output`
contributions add**.""" at the indices specified in the :attr:`index` tensor along an given axis
:attr:`dim`. If multiple indices reference the same location, their
**negated contributions add** (`cf.` :meth:`~torch_scatter.scatter_add_`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \mathrm{output}_i - \sum_j \mathrm{input}_j
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
output (Tensor): The destination tensor
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_sub_
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.zeros(2, 6)
scatter_sub_(output, index, input, dim=1)
print(output)
.. testoutput::
0 0 -4 -3 -3 0
-2 -4 -4 -0 0 0
[torch.FloatTensor of size 2x6]
"""
return output.scatter_add_(dim, index, -input) return output.scatter_add_(dim, index, -input)
def scatter_sub(index, input, dim=0, size=None, fill_value=0): def scatter_sub(index, input, dim=0, size=None, fill_value=0):
"""Subtracts all values from the tensor :attr:`input` at the indices """Subtracts all values from the :attr:`input` tensor at the indices
specified in the :attr:`index` tensor along an given axis :attr:`dim`. The specified in the :attr:`index` tensor along an given axis :attr:`dim`
output size at dimension :attr:`dim` is given by :attr:`size` and must be (`cf.` :meth:`~torch_scatter.scatter_sub_` and
at least size `index.max(dim) - 1`. If :attr:`size` is not given, a minimal :meth:`~torch_scatter.scatter_add`).
sized output tensor is returned. The output tensor is prefilled with the
specified value from :attr:`fill_value`.
For one-dimensional tensors, the operation computes For one-dimensional tensors, the operation computes
...@@ -22,9 +57,6 @@ def scatter_sub(index, input, dim=0, size=None, fill_value=0): ...@@ -22,9 +57,6 @@ def scatter_sub(index, input, dim=0, size=None, fill_value=0):
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`. where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.
A more detailed explanation is described in
:meth:`~torch_scatter.scatter_sub_`.
Args: Args:
index (LongTensor): The indices of elements to scatter index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor input (Tensor): The source tensor
...@@ -37,10 +69,10 @@ def scatter_sub(index, input, dim=0, size=None, fill_value=0): ...@@ -37,10 +69,10 @@ def scatter_sub(index, input, dim=0, size=None, fill_value=0):
.. testsetup:: .. testsetup::
import torch import torch
from torch_scatter import scatter_sub
.. testcode:: .. testcode::
from torch_scatter import scatter_sub
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_sub(index, input, dim=1) output = scatter_sub(index, input, dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment