Commit 62c61224 authored by rusty1s

clean up code base / added new functions to readme / added docs for softmax functions

parent d63eb9c9
......@@ -34,6 +34,12 @@ The package consists of the following operations:
* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
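A minimal usage sketch of the two composite functions (the `src` and `index` values below are illustrative, not taken from the package documentation):

```python
import torch
from torch_scatter.composite import scatter_softmax, scatter_log_softmax

src = torch.tensor([1.0, 3.0, 2.0, 4.0])
index = torch.tensor([0, 1, 0, 1])  # groups: {0: [1.0, 2.0], 1: [3.0, 4.0]}

# Softmax is computed independently within each index group.
print(scatter_softmax(src, index))      # ≈ [0.2689, 0.2689, 0.7311, 0.7311]
print(scatter_log_softmax(src, index))  # ≈ [-1.3133, -1.3133, -0.3133, -0.3133]
```

Both calls default to `dim=-1`, so the normalization runs over the last dimension within each index group.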
......
Scatter Softmax
===============
.. automodule:: torch_scatter.composite
:noindex:
.. autofunction:: scatter_softmax
.. autofunction:: scatter_log_softmax
......@@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i
:caption: Package reference
functions/*
composite/*
Indices and tables
==================
......
from itertools import product
import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax
from test.utils import devices, tensor, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_softmax(dtype, device):
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
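# Elements of `src` are grouped by `index`: {0: [0.2, 0.2], 1: [0, -2.1, 3.2], 2: [7], 4: [-1, -inf]}; index 3 is empty.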
out = scatter_softmax(src, index)
out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
dim=-1)
expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0).to(device)
assert torch.allclose(out, expected)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_log_softmax(dtype, device):
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_log_softmax(src, index)
out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
dim=-1)
expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0).to(device)
assert torch.allclose(out, expected)
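Both composite functions are differentiable because they are built entirely from differentiable `scatter_*` primitives. A hypothetical attention-style sketch (tensor names and sizes are illustrative only, not part of the test suite):

```python
import torch
from torch_scatter import scatter_add
from torch_scatter.composite import scatter_softmax

scores = torch.randn(8, requires_grad=True)      # one raw score per edge
values = torch.randn(8)                          # one value per edge
target = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])  # destination node of each edge

alpha = scatter_softmax(scores, target)          # weights normalized per destination node
pooled = scatter_add(alpha * values, target)     # weighted sum per destination node
pooled.sum().backward()                          # gradients reach `scores`
```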
......@@ -4,27 +4,21 @@ import torch
import pytest
from torch_scatter import scatter_logsumexp
from .utils import devices, tensor
from .utils import devices, tensor, grad_dtypes
SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
@pytest.mark.parametrize('dtype,device',
product(SUPPORTED_FLOAT_DTYPES, devices))
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_logsumexp(dtype, device):
src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_logsumexp(src, index)
idx0 = torch.logsumexp(
torch.tensor([0.5, 0.5], dtype=dtype),
dim=-1).tolist()
idx1 = torch.logsumexp(
torch.tensor([0, -2.1, 3.2], dtype=dtype),
dim=-1).tolist()
idx2 = 7 # Single element
idx3 = torch.finfo(dtype).min # Empty index, returns yield value
idx4 = -1 # logsumexp with -inf is the identity
out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
out4 = torch.tensor(-1, dtype=dtype)
assert out.tolist() == [idx0, idx1, idx2, idx3, idx4]
expected = torch.stack([out0, out1, out2, out3, out4], dim=0).to(device)
assert torch.allclose(out, expected)
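As a quick hand check of the expected values above (worked numbers, not part of the test suite): the first group reduces to `logsumexp([0.5, 0.5]) = 0.5 + log 2 ≈ 1.1931`, while the empty index 3 keeps the fill value `torch.finfo(dtype).min`.

```python
import math
import torch

print(0.5 + math.log(2))                                  # 1.1931...
print(torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1))  # tensor(1.1931)
```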
from itertools import product
import numpy as np
import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax
from .utils import devices, tensor
SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
@pytest.mark.parametrize('dtype,device',
product(SUPPORTED_FLOAT_DTYPES, devices))
def test_log_softmax(dtype, device):
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_log_softmax(src, index)
# Expected results per index
idx0 = [np.log(0.5), np.log(0.5)]
idx1 = torch.log_softmax(
torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
dim=-1).tolist()
idx2 = 0.0 # Single element, has logprob=0
# index=3 is empty. Should not matter.
idx4 = [0.0, float('-inf')] # log_softmax with -inf preserves the -inf
np.testing.assert_allclose(
out.tolist(),
[idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
rtol=1e-05, atol=1e-10
)
@pytest.mark.parametrize('dtype,device',
product(SUPPORTED_FLOAT_DTYPES, devices))
def test_softmax(dtype, device):
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')],
dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_softmax(src, index)
# Expected results per index
idx0 = [0.5, 0.5]
idx1 = torch.softmax(
torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
dim=-1).tolist()
idx2 = 1 # Single element, has prob=1
# index=3 is empty. Should not matter.
idx4 = [1.0, 0.0] # softmax with -inf yields zero probability
np.testing.assert_allclose(
out.tolist(),
[idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
rtol=1e-05, atol=1e-10
)
......@@ -7,6 +7,7 @@ from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
from .logsumexp import scatter_logsumexp
import torch_scatter.composite
__version__ = '1.3.2'
......@@ -20,5 +21,6 @@ __all__ = [
'scatter_max',
'scatter_min',
'scatter_logsumexp',
'torch_scatter',
'__version__',
]
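A small sketch of what the added `import torch_scatter.composite` line is meant to enable (assuming an installed package as laid out in this commit):

```python
import torch_scatter  # the top-level import also loads torch_scatter.composite

print(torch_scatter.__version__)  # '1.3.2'

from torch_scatter.composite import scatter_softmax, scatter_log_softmax
```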
......@@ -3,125 +3,84 @@ import torch
from torch_scatter import scatter_add, scatter_max
def scatter_log_softmax(src, index, dim=-1, dim_size=None):
def scatter_softmax(src, index, dim=-1, eps=1e-12):
r"""
Numerical safe log-softmax of all values from
the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`.If multiple indices reference the same location, their
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
Softmax operation over all values in :attr:`src` tensor that share indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = softmax(\mathrm{src}_i) =
\mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)
\mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
\frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Compute a numerically safe log softmax operation
from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('log_softmax can be computed only over '
'tensors with floating point data types.')
raise ValueError('`scatter_softmax` can only be computed over tensors '
'with floating point data types.')
max_value_per_index, _ = scatter_max(src, index,
dim=dim,
dim_size=dim_size)
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
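# Shifting by the per-group maximum (clamped to at least 0 via fill_value=0) keeps every
# exponent non-positive, so exp() cannot overflow; a per-group constant shift does not
# change the softmax.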
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
recentered_scores_exp = recentered_scores.exp()
sum_per_index = scatter_add(
src=recentered_scores.exp(),
index=index,
dim=dim,
dim_size=dim_size
)
log_normalizing_constants = sum_per_index.log().gather(dim, index)
sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
normalizing_constants = (sum_per_index + eps).gather(dim, index)
return recentered_scores - log_normalizing_constants
return recentered_scores_exp / normalizing_constants
def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
r"""
Numerical safe log-softmax of all values from
the :attr:`src` tensor into :attr:`out` at the
Log-softmax operation over all values in :attr:`src` tensor that share
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
:attr:`dim`.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = softmax(\mathrm{src}_i) =
\frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)}
\mathrm{out}_i = {\textrm{log\_softmax}(\mathrm{src})}_i =
\log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
\right)
where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Compute a numerically safe softmax operation
from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('softmax can be computed only over '
raise ValueError('`scatter_log_softmax` can only be computed over '
'tensors with floating point data types.')
max_value_per_index, _ = scatter_max(src, index,
dim=dim,
dim_size=dim_size)
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
exped_recentered_scores = recentered_scores.exp()
sum_per_index = scatter_add(
src=exped_recentered_scores,
index=index,
dim=dim,
dim_size=dim_size
)
normalizing_constant = (sum_per_index + epsilon).gather(dim, index)
return exped_recentered_scores / normalizing_constant
sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
dim=dim)
normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)
return recentered_scores - normalizing_constants
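Both composite functions rely on the standard max-shift trick: the per-group maximum (obtained via `scatter_max`) is subtracted before exponentiating, which keeps `exp` from overflowing while leaving the (log-)softmax unchanged. A stand-alone sketch of the same idea for a single group, in plain PyTorch with illustrative values:

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

# Naive softmax overflows in float32: exp(1000) is inf, so the result is all nan.
print(x.exp() / x.exp().sum())              # tensor([nan, nan, nan])

# Subtracting the maximum first gives the same softmax without overflow.
shifted = x - x.max()
print(shifted.exp() / shifted.exp().sum())  # tensor([0.0900, 0.2447, 0.6652])
```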
......@@ -4,26 +4,22 @@ from . import scatter_add, scatter_max
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
fill_value=None, epsilon=1e-16):
r"""
Numerically safe logsumexp of all values from
the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`).
fill_value=None, eps=1e-12):
r"""Fills :attr:`out` with the log of summed exponentials of all values
from the :attr:`src` tensor at the indices specified in the :attr:`index`
tensor along a given axis :attr:`dim`.
If multiple indices reference the same location, their
**exponential contributions add**
(`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \log \left( \exp(\mathrm{out}_i)
+ \sum_j \exp(\mathrm{src}_j) \right)
\mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
\exp(\mathrm{src}_j) \right)
Compute a numerically safe logsumexp operation
from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
......@@ -36,35 +32,23 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('logsumexp can only be computed over '
raise ValueError('`scatter_logsumexp` can only be computed over '
'tensors with floating point data types.')
if fill_value is None:
fill_value = torch.finfo(src.dtype).min
dim_size = out.shape[dim] \
if dim_size is None and out is not None else dim_size
max_value_per_index, _ = scatter_max(src, index, dim=dim,
out=out, dim_size=dim_size,
fill_value=fill_value)
max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
fill_value)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
out = (out - max_per_src_element).exp() if out is not None else None
sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
dim_size, fill_value=0)
sum_per_index = scatter_add(
src=recentered_scores.exp(),
index=index,
dim=dim,
out=(out - max_per_src_element).exp() if out is not None else None,
dim_size=dim_size,
fill_value=0,
)
return torch.log(sum_per_index + epsilon) + max_value_per_index
return torch.log(sum_per_index + eps) + max_value_per_index
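`scatter_logsumexp` applies the same shift; conceptually it computes `log(scatter_add(src.exp(), index))` for each non-empty group, just without overflow. A rough equivalence check with illustrative values (ignoring the `out` and `eps` details):

```python
import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0.5, -2.1, 3.2])
index = torch.tensor([0, 0, 1, 1])

out = scatter_logsumexp(src, index)
reference = torch.stack([
    torch.logsumexp(src[:2], dim=-1),   # group 0
    torch.logsumexp(src[2:], dim=-1),   # group 1
])
print(torch.allclose(out, reference, atol=1e-4))  # True (up to the eps term)
```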