Commit 356d0fe8 authored by rusty1s

update doc

parent 99db5b80
@@ -39,6 +39,7 @@ install:
 - pip install codecov
 - pip install sphinx
 - pip install sphinx_rtd_theme
+- pip install sphinx-autodoc-typehints
 script:
 - python -c "import torch; print(torch.__version__)"
 - pycodestyle .
...
@@ -22,22 +22,17 @@
 **[Documentation](https://pytorch-scatter.readthedocs.io)**

-This package consists of a small extension library of highly optimized sparse update (scatter/segment) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package.
+This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package.
 Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
+Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements.

-The package consists of the following operations:
+The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`:

-* [**Scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html)
-* [**SegmentCOO**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html)
-* [**SegmentCSR**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html)
-
-In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
-
-* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_std)
-* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_logsumexp)
-* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
-* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
+* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment.html) based on arbitrary indices
+* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
+* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers

-All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable via `@torch.jit.script`.
+All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.

 ## Installation
...
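The three index formats listed in the README hunk above describe the same grouping in different ways. A plain-Python illustration of how they relate (variable names here are hypothetical, not part of the package):

```python
# The same grouping of six values into three groups, expressed in the
# three index formats used by scatter, segment_coo and segment_csr:
scatter_index = [2, 0, 1, 1, 0, 2]  # scatter: group ids in arbitrary order
coo_index = [0, 0, 1, 1, 2, 2]      # segment_coo: sorted group ids
csr_indptr = [0, 2, 4, 6]           # segment_csr: boundaries between groups

# Expanding the CSR pointers recovers the sorted COO indices:
expanded = [k for k in range(len(csr_indptr) - 1)
            for _ in range(csr_indptr[k + 1] - csr_indptr[k])]
print(expanded == coo_index)  # → True
```

Sorting `scatter_index` would likewise yield `coo_index`, which is why segment operations can assume contiguous runs per group.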
@@ -3,3 +3,4 @@ numpy
 torch_nightly
 sphinx
 sphinx_rtd_theme
+sphinx-autodoc-typehints
Scatter Softmax
===============

.. automodule:: torch_scatter.composite
    :noindex:

.. autofunction:: scatter_softmax

.. autofunction:: scatter_log_softmax
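The composite softmax documented above normalizes each value over the entries sharing its group index. A minimal pure-Python sketch of that semantics (the helper name is hypothetical; the real `scatter_softmax` operates on PyTorch tensors):

```python
import math

def scatter_softmax_1d(src, index):
    # Group-wise softmax: each src[j] is normalized over the entries
    # whose index[j] matches its own (illustration only, not the
    # torch_scatter.composite implementation).
    n_groups = max(index) + 1
    # Subtract the per-group maximum for numerical stability.
    group_max = [-math.inf] * n_groups
    for value, group in zip(src, index):
        group_max[group] = max(group_max[group], value)
    exp = [math.exp(v - group_max[g]) for v, g in zip(src, index)]
    denom = [0.0] * n_groups
    for e, g in zip(exp, index):
        denom[g] += e
    return [e / denom[g] for e, g in zip(exp, index)]

probs = scatter_softmax_1d([1.0, 2.0, 1.0, 2.0], [0, 0, 1, 1])
print(probs)  # each group's probabilities sum to 1
```

`scatter_log_softmax` follows the same grouping, returning log-probabilities instead.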
@@ -11,6 +11,7 @@ extensions = [
     'sphinx.ext.napoleon',
     'sphinx.ext.viewcode',
     'sphinx.ext.githubpages',
+    'sphinx_autodoc_typehints',
 ]

 source_suffix = '.rst'
...
-Scatter Add
-===========
+Scatter
+=======

 .. automodule:: torch_scatter
     :noindex:

-.. autofunction:: scatter_add
+.. autofunction:: scatter

-Scatter Div
-===========
-
-.. automodule:: torch_scatter
-    :noindex:
-
-.. autofunction:: scatter_div

-Scatter LogSumExp
-=================
-
-.. automodule:: torch_scatter
-    :noindex:
-
-.. autofunction:: scatter_logsumexp

-Scatter Max
-===========
-
-.. automodule:: torch_scatter
-    :noindex:
-
-.. autofunction:: scatter_max

-Scatter Mean
-============
-
-.. automodule:: torch_scatter
-    :noindex:
-
-.. autofunction:: scatter_mean

-Scatter Min
-===========
-
-.. automodule:: torch_scatter
-
-.. autofunction:: scatter_min

-Scatter Mul
-===========
-
-.. automodule:: torch_scatter
-    :noindex:
-
-.. autofunction:: scatter_mul

-Scatter Std
-===========
-
-.. automodule:: torch_scatter
-    :noindex:
-
-.. autofunction:: scatter_std

-Scatter Sub
-===========
-
-.. automodule:: torch_scatter
-    :noindex:
-
-.. autofunction:: scatter_sub
@@ -7,7 +7,7 @@ This package consists of a small extension library of highly optimized sparse up
 Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
 Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements.

-All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
+All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.

 .. toctree::
    :glob:
@@ -15,7 +15,6 @@ All included operations are broadcastable, work on varying data types, and are i
    :caption: Package reference

    functions/*
-   composite/*

 Indices and tables
 ==================
...
@@ -9,7 +9,10 @@ from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

 WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
-WITH_CUDA = WITH_CUDA or os.getenv('FORCE_CUDA', '0') == '1'
+if os.getenv('FORCE_CUDA', '0') == '1':
+    WITH_CUDA = True
+if os.getenv('FORCE_NON_CUDA', '0') == '1':
+    WITH_CUDA = False

 def get_extensions():
...
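The setup.py change above replaces a one-way `or` with two explicit overrides, so `FORCE_NON_CUDA` (checked last) can now disable a detected CUDA build. A small sketch of that resolution order, with a hypothetical helper wrapping the same logic for testability:

```python
import os

def resolve_with_cuda(detected_cuda, env=None):
    # Mirrors the patched setup.py: start from autodetection, then let
    # FORCE_CUDA switch the build on and FORCE_NON_CUDA (checked last,
    # so it wins) switch it off.
    env = os.environ if env is None else env
    with_cuda = detected_cuda
    if env.get('FORCE_CUDA', '0') == '1':
        with_cuda = True
    if env.get('FORCE_NON_CUDA', '0') == '1':
        with_cuda = False
    return with_cuda

print(resolve_with_cuda(True, {'FORCE_NON_CUDA': '1'}))  # → False
print(resolve_with_cuda(False, {'FORCE_CUDA': '1'}))     # → True
```

With the old `or` expression, `FORCE_NON_CUDA` had no effect once CUDA was detected; the new form supports both directions.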
@@ -44,7 +44,6 @@ def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)

-@torch.jit.script
 def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
             out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
             reduce: str = "sum") -> torch.Tensor:
@@ -58,64 +57,68 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     |

-    Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
-    specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
-    each value in :attr:`src`, its output index is specified by its index in
-    :attr:`src` for dimensions outside of :attr:`dim` and by the
-    corresponding value in :attr:`index` for dimension :attr:`dim`. If
-    multiple indices reference the same location, their **contributions add**.
+    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
+    indices specified in the :attr:`index` tensor along a given axis
+    :attr:`dim`.
+    For each value in :attr:`src`, its output index is specified by its index
+    in :attr:`src` for dimensions outside of :attr:`dim` and by the
+    corresponding value in :attr:`index` for dimension :attr:`dim`.
+    The applied reduction is defined via the :attr:`reduce` argument.

-    Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with
-    size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
-    :attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
-    size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
-    values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
-    Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions
-    do not match.
+    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
+    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
+    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
+    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
+    Moreover, the values of :attr:`index` must be between :math:`0` and
+    :math:`y - 1` in ascending order.
+    The :attr:`index` tensor supports broadcasting in case its dimensions do
+    not match with :attr:`src`.

-    For one-dimensional tensors, the operation computes
+    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
+    computes

     .. math::
-        \mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j
+        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

     where :math:`\sum_j` is over :math:`j` such that
     :math:`\mathrm{index}_j = i`.

-    Args:
-        src (Tensor): The source tensor.
-        index (LongTensor): The indices of elements to scatter.
-        dim (int, optional): The axis along which to index.
-            (default: :obj:`-1`)
-        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
-        dim_size (int, optional): If :attr:`out` is not given, automatically
-            create output with size :attr:`dim_size` at dimension :attr:`dim`.
-            If :attr:`dim_size` is not given, a minimal sized output tensor is
-            returned. (default: :obj:`None`)
-        fill_value (int, optional): If :attr:`out` is not given, automatically
-            fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
+    .. note::
+
+        This operation is implemented via atomic operations on the GPU and is
+        therefore **non-deterministic** since the order of parallel operations
+        to the same value is undetermined.
+        For floating-point variables, this results in a source of variance in
+        the result.
+
+    :param src: The source tensor.
+    :param index: The indices of elements to scatter.
+    :param dim: The axis along which to index. (default: :obj:`-1`)
+    :param out: The destination tensor.
+    :param dim_size: If :attr:`out` is not given, automatically create output
+        with size :attr:`dim_size` at dimension :attr:`dim`.
+        If :attr:`dim_size` is not given, a minimal sized output tensor
+        according to :obj:`index.max() + 1` is returned.
+    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
+        :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)

     :rtype: :class:`Tensor`

-    .. testsetup::
-
-        import torch
-
-    .. testcode::
-
-        from torch_scatter import scatter_add
-
-        src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
-        index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
-        out = src.new_zeros((2, 6))
-
-        out = scatter_add(src, index, out=out)
-
-        print(out)
-
-    .. testoutput::
-
-        tensor([[0., 0., 4., 3., 3., 0.],
-                [2., 4., 4., 0., 0., 0.]])
+    .. code-block:: python
+
+        from torch_scatter import scatter
+
+        src = torch.randn(10, 6, 64)
+        index = torch.tensor([0, 1, 0, 1, 2, 1])
+
+        # Broadcasting in the first and last dim.
+        out = scatter(src, index, dim=1, reduce="sum")
+
+        print(out.size())
+
+    .. code-block::
+
+        torch.Size([10, 3, 64])
     """
     if reduce == 'sum' or reduce == 'add':
         return scatter_sum(src, index, dim, out, dim_size)
...
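The rewritten docstring above documents the one-dimensional `reduce="sum"` case; the other reduction types follow the same grouping. A pure-Python sketch of the 1-D semantics (illustrative only; the real `torch_scatter.scatter` works on tensors along an arbitrary `dim`):

```python
def scatter_1d(src, index, dim_size=None, reduce="sum"):
    # out[i] reduces every src[j] with index[j] == i, per the docstring.
    if dim_size is None:
        dim_size = max(index) + 1  # minimal sized output: index.max() + 1
    if reduce in ("sum", "mean"):
        out = [0.0] * dim_size
    elif reduce == "min":
        out = [float("inf")] * dim_size
    else:  # "max"
        out = [float("-inf")] * dim_size
    count = [0] * dim_size
    for value, group in zip(src, index):
        count[group] += 1
        if reduce in ("sum", "mean"):
            out[group] += value
        elif reduce == "min":
            out[group] = min(out[group], value)
        else:
            out[group] = max(out[group], value)
    if reduce == "mean":
        out = [o / c if c else 0.0 for o, c in zip(out, count)]
    return out

print(scatter_1d([2.0, 4.0, 6.0], [0, 1, 0], reduce="mean"))  # → [4.0, 4.0]
```

Unlike segment operations, `index` here may appear in any order, which is why the GPU path needs atomics and is non-deterministic for floating point.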
@@ -44,7 +44,6 @@ def segment_max_coo(src: torch.Tensor, index: torch.Tensor,
     return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size)

-@torch.jit.script
 def segment_coo(src: torch.Tensor, index: torch.Tensor,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None,
@@ -77,6 +76,7 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor,
     :math:`y - 1` in ascending order.
     The :attr:`index` tensor supports broadcasting in case its dimensions do
     not match with :attr:`src`.
+
     For one-dimensional tensors with :obj:`reduce="sum"`, the operation
     computes
@@ -91,10 +91,6 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor,
     Due to the use of sorted indices, :meth:`segment_coo` is usually faster
     than the more general :meth:`scatter` operation.

-    For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
-    second tensor representing the :obj:`argmin` and :obj:`argmax`,
-    respectively.
-
     .. note::

         This operation is implemented via atomic operations on the GPU and is
@@ -103,23 +99,19 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor,
         For floating-point variables, this results in a source of variance in
         the result.

-    Args:
-        src (Tensor): The source tensor.
-        index (LongTensor): The sorted indices of elements to segment.
-            The number of dimensions of :attr:`index` needs to be less than or
-            equal to :attr:`src`.
-        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
-        dim_size (int, optional): If :attr:`out` is not given, automatically
-            create output with size :attr:`dim_size` at dimension
-            :obj:`index.dim() - 1`.
-            If :attr:`dim_size` is not given, a minimal sized output tensor
-            according to :obj:`index.max() + 1` is returned.
-            (default: :obj:`None`)
-        reduce (string, optional): The reduce operation (:obj:`"sum"`,
-            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
-            (default: :obj:`"sum"`)
-
-    :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
+    :param src: The source tensor.
+    :param index: The sorted indices of elements to segment.
+        The number of dimensions of :attr:`index` needs to be less than or
+        equal to :attr:`src`.
+    :param out: The destination tensor.
+    :param dim_size: If :attr:`out` is not given, automatically create output
+        with size :attr:`dim_size` at dimension :obj:`index.dim() - 1`.
+        If :attr:`dim_size` is not given, a minimal sized output tensor
+        according to :obj:`index.max() + 1` is returned.
+    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
+        :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
+
+    :rtype: :class:`Tensor`

     .. code-block:: python
...
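The sorted-index contract in the `segment_coo` docstring above means every group occupies one contiguous run. A plain-Python sketch of that 1-D `"sum"` case (a hypothetical helper, not the library's CPU/CUDA implementation):

```python
def segment_coo_sum(src, index):
    # index must be non-decreasing, so each group is a contiguous run
    # and a single sequential sweep reduces it; this contiguity is why
    # segment_coo is usually faster than the more general scatter.
    assert all(a <= b for a, b in zip(index, index[1:])), "index must be sorted"
    out = [0] * (index[-1] + 1)
    for value, group in zip(src, index):
        out[group] += value
    return out

print(segment_coo_sum([1, 2, 3, 4], [0, 0, 1, 2]))  # → [3, 3, 4]
```

Passing unsorted indices violates the documented contract, which the sketch makes explicit with an assertion.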
@@ -39,7 +39,6 @@ def segment_max_csr(src: torch.Tensor, indptr: torch.Tensor,
     return torch.ops.torch_scatter.segment_max_csr(src, indptr, out)

-@torch.jit.script
 def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
                 out: Optional[torch.Tensor] = None,
                 reduce: str = "sum") -> torch.Tensor:
@@ -63,6 +62,7 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
     :math:`x_m` in ascending order.
     The :attr:`indptr` tensor supports broadcasting in case its dimensions do
     not match with :attr:`src`.
+
     For one-dimensional tensors with :obj:`reduce="sum"`, the operation
     computes
@@ -73,26 +73,20 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
     Due to the use of index pointers, :meth:`segment_csr` is the fastest
     method to apply for grouped reductions.

-    For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
-    second tensor representing the :obj:`argmin` and :obj:`argmax`,
-    respectively.
-
     .. note::

         In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
         operation is **fully-deterministic**.

-    Args:
-        src (Tensor): The source tensor.
-        indptr (LongTensor): The index pointers between elements to segment.
-            The number of dimensions of :attr:`index` needs to be less than or
-            equal to :attr:`src`.
-        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
-        reduce (string, optional): The reduce operation (:obj:`"sum"`,
-            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
-            (default: :obj:`"sum"`)
-
-    :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
+    :param src: The source tensor.
+    :param indptr: The index pointers between elements to segment.
+        The number of dimensions of :attr:`index` needs to be less than or
+        equal to :attr:`src`.
+    :param out: The destination tensor.
+    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
+        :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
+
+    :rtype: :class:`Tensor`

     .. code-block:: python
...
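In the CSR format documented above, `indptr` of length m+1 delimits m segments: segment k covers `src[indptr[k]:indptr[k+1]]`. A minimal pure-Python sketch of the 1-D `"max"` case (hypothetical helper; the library version additionally broadcasts and supports batched `indptr`):

```python
def segment_csr_max(src, indptr):
    # Segment k is the half-open slice src[indptr[k]:indptr[k + 1]];
    # empty segments fall back to -inf. Because each output element is
    # computed from its own fixed slice, the result is deterministic.
    return [max(src[indptr[k]:indptr[k + 1]], default=float("-inf"))
            for k in range(len(indptr) - 1)]

print(segment_csr_max([1, 5, 2, 4, 3], [0, 2, 5]))  # → [5, 4]
```

The fixed per-segment slices are also what makes `segment_csr` fully deterministic, in contrast to the atomic-based `scatter` and `segment_coo` GPU paths.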