"packaging/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "226126b83b0cb05c0fe26b3d3c8ee0f9c8f0d992"
Commit dd50d35f authored by Miltos Allamanis's avatar Miltos Allamanis
Browse files

A first round of implementation of scatter_logsumexp/softmax/logsoftmax ops.

parent 78a55495
Scatter LogSumExp
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_logsumexp
...@@ -6,6 +6,7 @@ from .mean import scatter_mean ...@@ -6,6 +6,7 @@ from .mean import scatter_mean
from .std import scatter_std from .std import scatter_std
from .max import scatter_max from .max import scatter_max
from .min import scatter_min from .min import scatter_min
from .logsumexp import scatter_logsumexp
__version__ = '1.3.2' __version__ = '1.3.2'
...@@ -18,5 +19,6 @@ __all__ = [ ...@@ -18,5 +19,6 @@ __all__ = [
'scatter_std', 'scatter_std',
'scatter_max', 'scatter_max',
'scatter_min', 'scatter_min',
'scatter_logsumexp',
'__version__', '__version__',
] ]
from .softmax import scatter_log_softmax, scatter_softmax
__all__ = [
'scatter_softmax',
'scatter_log_softmax'
]
\ No newline at end of file
import torch
from torch_scatter.logsumexp import _scatter_logsumexp
def scatter_log_softmax(src, index, dim=-1, dim_size=None):
r"""
Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`.If multiple indices reference the same location, their
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = softmax(\mathrm{src}_i) = \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)
where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Compute a numerically safe log softmax operation
from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: :class:`Tensor`
"""
per_index_logsumexp, recentered_src = _scatter_logsumexp(src, index, dim=dim, dim_size=dim_size)
return recentered_src - per_index_logsumexp.gather(dim, index)
def scatter_softmax(src, index, dim=-1, dim_size=None):
r"""
Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = softmax(\mathrm{src}_i) = \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)}
where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Compute a numerically safe softmax operation
from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: :class:`Tensor`
"""
return scatter_log_softmax(src, index, dim, dim_size).exp()
import torch
from . import scatter_add, scatter_max
EPSILON = 1e-16
def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
if not torch.is_floating_point(src):
raise ValueError('logsumexp can be computed over tensors floating point data types.')
if fill_value is None:
fill_value = torch.finfo(src.dtype).min
dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size
max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
sum_per_index = scatter_add(
src=recentered_scores.exp(),
index=index,
dim=dim,
out=(src - max_per_src_element).exp() if out is not None else None,
dim_size=dim_size,
fill_value=fill_value,
)
return torch.log(sum_per_index + EPSILON) + max_value_per_index, recentered_scores
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
r"""
Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + \sum_j \exp(\mathrm{src}_j) \right)
Compute a numerically safe logsumexp operation
from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: :class:`Tensor`
"""
return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value)[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment