Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
dd50d35f
Commit
dd50d35f
authored
Nov 04, 2019
by
Miltos Allamanis
Browse files
A first round of implementation of scatter_logsumexp/softmax/logsoftmax ops.
parent
78a55495
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
166 additions
and
0 deletions
+166
-0
docs/source/functions/logsumexp.rst
docs/source/functions/logsumexp.rst
+7
-0
torch_scatter/__init__.py
torch_scatter/__init__.py
+2
-0
torch_scatter/composite/__init__.py
torch_scatter/composite/__init__.py
+6
-0
torch_scatter/composite/softmax.py
torch_scatter/composite/softmax.py
+85
-0
torch_scatter/logsumexp.py
torch_scatter/logsumexp.py
+66
-0
No files found.
docs/source/functions/logsumexp.rst
0 → 100644
View file @
dd50d35f
Scatter LogSumExp
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_logsumexp
torch_scatter/__init__.py
View file @
dd50d35f
...
@@ -6,6 +6,7 @@ from .mean import scatter_mean
...
@@ -6,6 +6,7 @@ from .mean import scatter_mean
from
.std
import
scatter_std
from
.std
import
scatter_std
from
.max
import
scatter_max
from
.max
import
scatter_max
from
.min
import
scatter_min
from
.min
import
scatter_min
from
.logsumexp
import
scatter_logsumexp
__version__
=
'1.3.2'
__version__
=
'1.3.2'
...
@@ -18,5 +19,6 @@ __all__ = [
...
@@ -18,5 +19,6 @@ __all__ = [
'scatter_std'
,
'scatter_std'
,
'scatter_max'
,
'scatter_max'
,
'scatter_min'
,
'scatter_min'
,
'scatter_logsumexp'
,
'__version__'
,
'__version__'
,
]
]
torch_scatter/composite/__init__.py
0 → 100644
View file @
dd50d35f
from
.softmax
import
scatter_log_softmax
,
scatter_softmax
__all__
=
[
'scatter_softmax'
,
'scatter_log_softmax'
]
\ No newline at end of file
torch_scatter/composite/softmax.py
0 → 100644
View file @
dd50d35f
import
torch
from
torch_scatter.logsumexp
import
_scatter_logsumexp
def
scatter_log_softmax
(
src
,
index
,
dim
=-
1
,
dim_size
=
None
):
r
"""
Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`.If multiple indices reference the same location, their
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = softmax(\mathrm{src}_i) = \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)
where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Compute a numerically safe log softmax operation
from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: :class:`Tensor`
"""
per_index_logsumexp
,
recentered_src
=
_scatter_logsumexp
(
src
,
index
,
dim
=
dim
,
dim_size
=
dim_size
)
return
recentered_src
-
per_index_logsumexp
.
gather
(
dim
,
index
)
def
scatter_softmax
(
src
,
index
,
dim
=-
1
,
dim_size
=
None
):
r
"""
Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = softmax(\mathrm{src}_i) = \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)}
where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Compute a numerically safe softmax operation
from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: :class:`Tensor`
"""
return
scatter_log_softmax
(
src
,
index
,
dim
,
dim_size
).
exp
()
torch_scatter/logsumexp.py
0 → 100644
View file @
dd50d35f
import
torch
from
.
import
scatter_add
,
scatter_max
EPSILON
=
1e-16
def
_scatter_logsumexp
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
None
):
if
not
torch
.
is_floating_point
(
src
):
raise
ValueError
(
'logsumexp can be computed over tensors floating point data types.'
)
if
fill_value
is
None
:
fill_value
=
torch
.
finfo
(
src
.
dtype
).
min
dim_size
=
out
.
shape
[
dim
]
if
dim_size
is
None
and
out
is
not
None
else
dim_size
max_value_per_index
,
_
=
scatter_max
(
src
,
index
,
dim
=
dim
,
out
=
out
,
dim_size
=
dim_size
,
fill_value
=
fill_value
)
max_per_src_element
=
max_value_per_index
.
gather
(
dim
,
index
)
recentered_scores
=
src
-
max_per_src_element
sum_per_index
=
scatter_add
(
src
=
recentered_scores
.
exp
(),
index
=
index
,
dim
=
dim
,
out
=
(
src
-
max_per_src_element
).
exp
()
if
out
is
not
None
else
None
,
dim_size
=
dim_size
,
fill_value
=
fill_value
,
)
return
torch
.
log
(
sum_per_index
+
EPSILON
)
+
max_value_per_index
,
recentered_scores
def
scatter_logsumexp
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
None
):
r
"""
Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + \sum_j \exp(\mathrm{src}_j) \right)
Compute a numerically safe logsumexp operation
from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: :class:`Tensor`
"""
return
_scatter_logsumexp
(
src
,
index
,
dim
,
out
,
dim_size
,
fill_value
)[
0
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment