Unverified commit d790f1ca, authored by Matthias Fey, committed by GitHub

Merge pull request #77 from mallamanis/master

Implement scatter_logsumexp, scatter_softmax, scatter_log_softmax
parents 78a55495 62c61224
@@ -34,6 +34,12 @@ The package consists of the following operations:
* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
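As a quick illustration of the composite functions above, here is a minimal usage sketch (the values mirror the tests added in this pull request):

```python
import torch
from torch_scatter.composite import scatter_softmax, scatter_log_softmax

src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])

# Positions sharing an index form one softmax group:
# group 0 -> src[0], src[2]; group 1 -> src[1], src[3], src[4].
out = scatter_softmax(src, index)          # each group sums to 1
log_out = scatter_log_softmax(src, index)  # log of the same values
```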
Scatter Softmax
===============
.. automodule:: torch_scatter.composite
    :noindex:

.. autofunction:: scatter_softmax

.. autofunction:: scatter_log_softmax
Scatter LogSumExp
=================
.. automodule:: torch_scatter
    :noindex:

.. autofunction:: scatter_logsumexp
@@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i
:caption: Package reference
functions/*
composite/*
Indices and tables
==================
from itertools import product
import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax
from test.utils import devices, tensor, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_softmax(dtype, device):
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_softmax(src, index)
out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
dim=-1)
expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0)
assert torch.allclose(out, expected)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_log_softmax(dtype, device):
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_log_softmax(src, index)
out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
dim=-1)
expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0)
assert torch.allclose(out, expected)
from itertools import product
import torch
import pytest
from torch_scatter import scatter_logsumexp
from .utils import devices, tensor, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_logsumexp(dtype, device):
src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_logsumexp(src, index)
out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
out4 = torch.tensor(-1, dtype=dtype)
expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
assert torch.allclose(out, expected)
@@ -6,6 +6,8 @@ from .mean import scatter_mean
from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
from .logsumexp import scatter_logsumexp
import torch_scatter.composite
__version__ = '1.3.2'
@@ -18,5 +20,7 @@ __all__ = [
'scatter_std',
'scatter_max',
'scatter_min',
'scatter_logsumexp',
'torch_scatter',
'__version__',
]
from .softmax import scatter_log_softmax, scatter_softmax
__all__ = [
'scatter_softmax',
'scatter_log_softmax',
]
import torch
from torch_scatter import scatter_add, scatter_max
def scatter_softmax(src, index, dim=-1, eps=1e-12):
r"""
Softmax operation over all values in the :attr:`src` tensor that share indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
\frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('`scatter_softmax` can only be computed over tensors '
'with floating point data types.')
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
recentered_scores_exp = recentered_scores.exp()
sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
normalizing_constants = (sum_per_index + eps).gather(dim, index)
return recentered_scores_exp / normalizing_constants
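The per-index maximum subtracted above cancels in the quotient, so the recentering only keeps the exponents bounded and does not change the result:

.. math::
    \frac{\exp(\mathrm{src}_i - m_{\mathrm{index}_i})}
    {\sum_j \exp(\mathrm{src}_j - m_{\mathrm{index}_i})} =
    \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)},
    \qquad m_k = \max_{j:\, \mathrm{index}_j = k} \mathrm{src}_j

where, as in the docstring above, :math:`\sum_j` runs over :math:`j` with :math:`\mathrm{index}_j = \mathrm{index}_i`.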
def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
r"""
Log-softmax operation over all values in the :attr:`src` tensor that share
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = {\textrm{log\_softmax}(\mathrm{src})}_i =
\log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
\right)
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('`scatter_log_softmax` can only be computed over '
'tensors with floating point data types.')
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
dim=dim)
normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)
return recentered_scores - normalizing_constants
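Both composite functions rely on this recentering plus the small `eps` guard in the denominator. A quick way to sanity-check a single group against PyTorch's built-ins (a small sketch mirroring the new tests, not part of this diff):

```python
import torch
from torch_scatter.composite import scatter_softmax, scatter_log_softmax

src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])

# Group 0 lives at positions 0 and 2; compare that slice against torch.softmax.
out = scatter_softmax(src, index)
assert torch.allclose(out[[0, 2]], torch.softmax(src[[0, 2]], dim=-1))

log_out = scatter_log_softmax(src, index)
assert torch.allclose(log_out[[0, 2]], torch.log_softmax(src[[0, 2]], dim=-1))
```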
import torch
from . import scatter_add, scatter_max
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
fill_value=None, eps=1e-12):
r"""Fills :attr:`out` with the log of summed exponentials of all values
from the :attr:`src` tensor at the indices specified in the :attr:`index`
tensor along a given axis :attr:`dim`.
If multiple indices reference the same location, their
**exponential contributions add**
(`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
\exp(\mathrm{src}_j) \right)
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('`scatter_logsumexp` can only be computed over '
'tensors with floating point data types.')
max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
fill_value)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
out = (out - max_per_src_element).exp() if out is not None else None
sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
dim_size, fill_value=0)
return torch.log(sum_per_index + eps) + max_value_per_index
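A minimal usage sketch for the new operation (the values mirror `test_logsumexp` above; each index group is reduced with a numerically stable log-sum-exp):

```python
import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0.0, 0.5, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])

out = scatter_logsumexp(src, index)
# out[0] reduces src[0] and src[2]; out[1] reduces src[1], src[3] and src[4].
assert torch.allclose(out[0], torch.logsumexp(src[[0, 2]], dim=-1))
assert torch.allclose(out[1], torch.logsumexp(src[[1, 3, 4]], dim=-1))
```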