Commit 62c61224 authored by rusty1s

clean up code base / added new functions to readme / added docs for softmax functions

parent d63eb9c9
@@ -34,6 +34,12 @@ The package consists of the following operations:
* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
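
As a quick, illustrative sketch of how the new composite functions are called (the example values below are not taken from the README itself; the functions group entries of `src` by the corresponding entry of `index` and normalize each group independently):

    import torch
    from torch_scatter.composite import scatter_softmax, scatter_log_softmax

    src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2])
    index = torch.tensor([0, 1, 0, 1, 1])

    # Softmax is computed independently within each group defined by `index`:
    # entries 0 and 2 form group 0, entries 1, 3 and 4 form group 1.
    out = scatter_softmax(src, index)          # probabilities per group
    log_out = scatter_log_softmax(src, index)  # log-probabilities per group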
...
Scatter Softmax
===============
.. automodule:: torch_scatter.composite
   :noindex:

.. autofunction:: scatter_softmax

.. autofunction:: scatter_log_softmax
@@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i
   :caption: Package reference

   functions/*
   composite/*
Indices and tables
==================
...
from itertools import product

import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

from test.utils import devices, tensor, grad_dtypes


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_softmax(dtype, device):
    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_softmax(src, index)

    out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
    out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
                         dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_log_softmax(dtype, device):
    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_log_softmax(src, index)

    out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
                             dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)
@@ -4,27 +4,21 @@ import torch
import pytest

from torch_scatter import scatter_logsumexp
from .utils import devices, tensor, grad_dtypes


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_logsumexp(dtype, device):
    src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
    out = scatter_logsumexp(src, index)

    out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
    out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
    out4 = torch.tensor(-1, dtype=dtype)

    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
    assert torch.allclose(out, expected)
from itertools import product

import numpy as np
import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

from .utils import devices, tensor

SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}


@pytest.mark.parametrize('dtype,device',
                         product(SUPPORTED_FLOAT_DTYPES, devices))
def test_log_softmax(dtype, device):
    src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
                 dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
    out = scatter_log_softmax(src, index)

    # Expected results per index
    idx0 = [np.log(0.5), np.log(0.5)]
    idx1 = torch.log_softmax(
        torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
        dim=-1).tolist()
    idx2 = 0.0  # Single element, has logprob=0
    # index=3 is empty. Should not matter.
    idx4 = [0.0, float('-inf')]  # log_softmax with -inf preserves the -inf

    np.testing.assert_allclose(
        out.tolist(),
        [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
        rtol=1e-05, atol=1e-10
    )


@pytest.mark.parametrize('dtype,device',
                         product(SUPPORTED_FLOAT_DTYPES, devices))
def test_softmax(dtype, device):
    src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
                 dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
    out = scatter_softmax(src, index)

    # Expected results per index
    idx0 = [0.5, 0.5]
    idx1 = torch.softmax(
        torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
        dim=-1).tolist()
    idx2 = 1  # Single element, has prob=1
    # index=3 is empty. Should not matter.
    idx4 = [1.0, 0.0]  # softmax with -inf yields zero probability

    np.testing.assert_allclose(
        out.tolist(),
        [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
        rtol=1e-05, atol=1e-10
    )
@@ -7,6 +7,7 @@ from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
from .logsumexp import scatter_logsumexp
import torch_scatter.composite
__version__ = '1.3.2'

@@ -20,5 +21,6 @@ __all__ = [
    'scatter_max',
    'scatter_min',
    'scatter_logsumexp',
    'torch_scatter',
    '__version__',
]
@@ -3,125 +3,84 @@ import torch
from torch_scatter import scatter_add, scatter_max


def scatter_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Softmax operation over all values in :attr:`src` tensor that share indices
    in the :attr:`index` tensor along a given axis :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp()

    sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
    normalizing_constants = (sum_per_index + eps).gather(dim, index)

    return recentered_scores_exp / normalizing_constants


def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Log-softmax operation over all values in :attr:`src` tensor that share
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{log\_softmax}(\mathrm{src})}_i =
        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
        \right)

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element

    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
                                dim=dim)

    normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)

    return recentered_scores - normalizing_constants
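
Both functions above rely on the standard max-subtraction trick: the per-group maximum is subtracted before exponentiating so the intermediate `exp` values stay bounded, and the small `eps` term guards against division by zero (or `log(0)`) for empty groups. As a rough sanity check of the math, a hypothetical group-wise reference in plain PyTorch (a sketch for illustration only, not part of this commit):

    import torch

    def reference_scatter_softmax(src, index):
        # Hypothetical group-wise reference, written only to illustrate the
        # formula above; it loops over groups instead of scattering.
        out = torch.empty_like(src)
        for group in index.unique():
            mask = index == group
            out[mask] = torch.softmax(src[mask], dim=-1)
        return out

    src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2])
    index = torch.tensor([0, 1, 0, 1, 1])
    # scatter_softmax(src, index) is expected to closely match this reference,
    # up to the small `eps` term used for numerical stability.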
@@ -4,26 +4,22 @@ from . import scatter_add, scatter_max

def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
                      fill_value=None, eps=1e-12):
    r"""Fills :attr:`out` with the log of summed exponentials of all values
    from the :attr:`src` tensor at the indices specified in the :attr:`index`
    tensor along a given axis :attr:`dim`.
    If multiple indices reference the same location, their
    **exponential contributions add**
    (`cf.` :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
        \exp(\mathrm{src}_j) \right)

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.

@@ -36,35 +32,23 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
                                         fill_value)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element
    out = (out - max_per_src_element).exp() if out is not None else None

    sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
                                dim_size, fill_value=0)

    return torch.log(sum_per_index + eps) + max_value_per_index
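
The same recentering trick drives `scatter_logsumexp`: the per-group maximum is factored out, the remaining exponentials are summed with `scatter_add`, and the maximum is added back in log-space. A small sanity check against per-group `torch.logsumexp` (a sketch assuming the API shown above; it is not part of the commit's test suite):

    import torch
    from torch_scatter import scatter_logsumexp

    src = torch.tensor([0.5, 0.0, 0.5, -2.1, 3.2])
    index = torch.tensor([0, 1, 0, 1, 1])
    out = scatter_logsumexp(src, index)

    # Each output entry i should equal logsumexp over all src[j] with
    # index[j] == i, up to the small `eps` added before the final log.
    for group in index.unique():
        reference = torch.logsumexp(src[index == group], dim=-1)
        assert torch.allclose(out[group], reference, atol=1e-5)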