Unverified Commit d790f1ca authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #77 from mallamanis/master

Implement scatter_logsumexp, scatter_softmax, scatter_log_softmax
parents 78a55495 62c61224
......@@ -34,6 +34,12 @@ The package consists of the following operations:
* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
......
Scatter Softmax
===============
.. automodule:: torch_scatter.composite
:noindex:
.. autofunction:: scatter_softmax
.. autofunction:: scatter_log_softmax
Scatter LogSumExp
=================
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_logsumexp
......@@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i
:caption: Package reference
functions/*
composite/*
Indices and tables
==================
......
from itertools import product
import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax
from test.utils import devices, tensor, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_softmax(dtype, device):
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_softmax(src, index)
out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
dim=-1)
expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0)
assert torch.allclose(out, expected)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_log_softmax(dtype, device):
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_log_softmax(src, index)
out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
dim=-1)
expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0)
assert torch.allclose(out, expected)
from itertools import product
import torch
import pytest
from torch_scatter import scatter_logsumexp
from .utils import devices, tensor, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_logsumexp(dtype, device):
src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_logsumexp(src, index)
out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
out4 = torch.tensor(-1, dtype=dtype)
expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
assert torch.allclose(out, expected)
......@@ -6,6 +6,8 @@ from .mean import scatter_mean
from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
from .logsumexp import scatter_logsumexp
import torch_scatter.composite
__version__ = '1.3.2'
......@@ -18,5 +20,7 @@ __all__ = [
'scatter_std',
'scatter_max',
'scatter_min',
'scatter_logsumexp',
'torch_scatter',
'__version__',
]
from .softmax import scatter_log_softmax, scatter_softmax
__all__ = [
'scatter_softmax',
'scatter_log_softmax',
]
import torch
from torch_scatter import scatter_add, scatter_max
def scatter_softmax(src, index, dim=-1, eps=1e-12):
r"""
Softmax operation over all values in :attr:`src` tensor that share indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
\frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('`scatter_softmax` can only be computed over tensors '
'with floating point data types.')
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
recentered_scores_exp = recentered_scores.exp()
sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
normalizing_constants = (sum_per_index + eps).gather(dim, index)
return recentered_scores_exp / normalizing_constants
def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
r"""
Log-softmax operation over all values in :attr:`src` tensor that share
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i =
\log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
\right)
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('`scatter_log_softmax` can only be computed over '
'tensors with floating point data types.')
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
dim=dim)
normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)
return recentered_scores - normalizing_constants
import torch
from . import scatter_add, scatter_max
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
fill_value=None, eps=1e-12):
r"""Fills :attr:`out` with the log of summed exponentials of all values
from the :attr:`src` tensor at the indices specified in the :attr:`index`
tensor along a given axis :attr:`dim`.
If multiple indices reference the same location, their
**exponential contributions add**
(`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
\exp(\mathrm{src}_j) \right)
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('`scatter_logsumexp` can only be computed over '
'tensors with floating point data types.')
max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
fill_value)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
out = (out - max_per_src_element).exp() if out is not None else None
sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
dim_size, fill_value=0)
return torch.log(sum_per_index + eps) + max_value_per_index
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment