OpenDAS / torch-scatter · Commit 62c61224

clean up code base / added new functions to readme / added docs for softmax functions

Authored Nov 08, 2019 by rusty1s · Parent: d63eb9c9
Showing 9 changed files with 129 additions and 187 deletions (+129, −187)
README.md                            +6   −0
docs/source/composite/softmax.rst    +9   −0
docs/source/index.rst                +1   −0
test/composite/test_softmax.py       +47  −0
test/test_logsumexp.py               +9   −15
test/test_softmax.py                 +0   −60
torch_scatter/__init__.py            +2   −0
torch_scatter/composite/softmax.py   +33  −74
torch_scatter/logsumexp.py           +22  −38
README.md

@@ -34,6 +34,12 @@ The package consists of the following operations:

 * [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
 * [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
 * [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
+* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
+
+In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
+
+* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
+* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)

 All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
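As a quick illustration of how the new composite functions are used (a minimal sketch, not part of the README; it assumes the import path and grouping used by the new tests below):

import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

# Eight scores grouped into segments 0, 1, 2 and 4; segment 3 stays empty.
src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2, 7.0, -1.0, float('-inf')])
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

# Each entry is normalized only against entries sharing its index, so the
# probabilities within every segment sum to one (up to the internal eps).
probs = scatter_softmax(src, index)
log_probs = scatter_log_softmax(src, index)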
docs/source/composite/softmax.rst (new file, 0 → 100644)

Scatter Softmax
===============

.. automodule:: torch_scatter.composite
    :noindex:

.. autofunction:: scatter_softmax

.. autofunction:: scatter_log_softmax
docs/source/index.rst

@@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i
    :caption: Package reference

    functions/*
+   composite/*

 Indices and tables
 ==================
test/composite/test_softmax.py (new file, 0 → 100644)

from itertools import product

import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

from test.utils import devices, tensor, grad_dtypes


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_softmax(dtype, device):
    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_softmax(src, index)

    out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
    out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
                         dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_log_softmax(dtype, device):
    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_log_softmax(src, index)

    out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
                             dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)
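Note that the trailing -inf entry deliberately exercises an edge case: softmax over [-1, -inf] yields probabilities [1, 0] because exp(-inf) = 0, while log-softmax preserves the -inf, and both tests check that the scatter variants reproduce this.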
test/test_logsumexp.py

@@ -4,27 +4,21 @@ import torch
 import pytest

 from torch_scatter import scatter_logsumexp
-from .utils import devices, tensor
+from .utils import devices, tensor, grad_dtypes

-SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}

-@pytest.mark.parametrize('dtype,device',
-                         product(SUPPORTED_FLOAT_DTYPES, devices))
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_logsumexp(dtype, device):
     src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
     index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
     out = scatter_logsumexp(src, index)

-    idx0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype),
-                           dim=-1).tolist()
-    idx1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype),
-                           dim=-1).tolist()
-    idx2 = 7  # Single element
-    idx3 = torch.finfo(dtype).min  # Empty index, returns yield value
-    idx4 = -1  # logsumexp with -inf is the identity
-
-    assert out.tolist() == [idx0, idx1, idx2, idx3, idx4]
+    out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
+    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
+    out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
+    out4 = torch.tensor(-1, dtype=dtype)
+
+    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
+    assert torch.allclose(out, expected)
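These expectations are just the definition worked out per segment: the empty segment 3 keeps the fill value torch.finfo(dtype).min, the singleton segment 2 returns its own value 7, and segment 4 gives log(exp(-1) + exp(-inf)) = log(exp(-1)) = -1, so a -inf partner leaves logsumexp unchanged.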
test/test_softmax.py (deleted, 100644 → 0; content as of parent d63eb9c9)

from itertools import product

import numpy as np
import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

from .utils import devices, tensor

SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}


@pytest.mark.parametrize('dtype,device',
                         product(SUPPORTED_FLOAT_DTYPES, devices))
def test_log_softmax(dtype, device):
    src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype,
                 device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
    out = scatter_log_softmax(src, index)

    # Expected results per index
    idx0 = [np.log(0.5), np.log(0.5)]
    idx1 = torch.log_softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
                             dim=-1).tolist()
    idx2 = 0.0  # Single element, has logprob=0
    # index=3 is empty. Should not matter.
    idx4 = [0.0, float('-inf')]  # log_softmax with -inf preserves the -inf

    np.testing.assert_allclose(
        out.tolist(),
        [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
        rtol=1e-05, atol=1e-10)


@pytest.mark.parametrize('dtype,device',
                         product(SUPPORTED_FLOAT_DTYPES, devices))
def test_softmax(dtype, device):
    src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype,
                 device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
    out = scatter_softmax(src, index)

    # Expected results per index
    idx0 = [0.5, 0.5]
    idx1 = torch.softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype),
                         dim=-1).tolist()
    idx2 = 1  # Single element, has prob=1
    # index=3 is empty. Should not matter.
    idx4 = [1.0, 0.0]  # softmax with -inf yields zero probability

    np.testing.assert_allclose(
        out.tolist(),
        [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
        rtol=1e-05, atol=1e-10)
torch_scatter/__init__.py

@@ -7,6 +7,7 @@ from .std import scatter_std
 from .max import scatter_max
 from .min import scatter_min
 from .logsumexp import scatter_logsumexp
+import torch_scatter.composite

 __version__ = '1.3.2'
@@ -20,5 +21,6 @@ __all__ = [
     'scatter_max',
     'scatter_min',
     'scatter_logsumexp',
+    'torch_scatter',
     '__version__',
 ]
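Since the package now imports torch_scatter.composite at the top level, the composite functions become reachable right after importing the package (a minimal sketch under this commit's layout):

import torch
import torch_scatter

src = torch.tensor([0.5, 1.5, 2.0])
index = torch.tensor([0, 0, 1])

# The composite submodule is an attribute of the package itself.
out = torch_scatter.composite.scatter_softmax(src, index)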
torch_scatter/composite/softmax.py

@@ -3,125 +3,84 @@ import torch
 from torch_scatter import scatter_add, scatter_max


-def scatter_log_softmax(src, index, dim=-1, dim_size=None):
+def scatter_softmax(src, index, dim=-1, eps=1e-12):
     r"""
-    Numerical safe log-softmax of all values from
-    the :attr:`src` tensor into :attr:`out` at the
-    indices specified in the :attr:`index` tensor along a given axis
-    :attr:`dim`. If multiple indices reference the same location, their
-    **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
+    Softmax operation over all values in :attr:`src` tensor that share indices
+    specified in the :attr:`index` tensor along a given axis :attr:`dim`.

     For one-dimensional tensors, the operation computes

     .. math::
-        \mathrm{out}_i = softmax(\mathrm{src}_i) = \mathrm{src}_i
-        - \mathrm{logsumexp}_j (\mathrm{src}_j)
+        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
+        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}

-    where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
+    where :math:`\sum_j` is over :math:`j` such that
     :math:`\mathrm{index}_j = i`.

-    Compute a numerically safe log softmax operation
-    from the :attr:`src` tensor into :attr:`out` at the indices
-    specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
-    each value in :attr:`src`, its output index is specified by its index in
-    :attr:`input` for dimensions outside of :attr:`dim` and by the
-    corresponding value in :attr:`index` for dimension :attr:`dim`.
-
     Args:
         src (Tensor): The source tensor.
         index (LongTensor): The indices of elements to scatter.
         dim (int, optional): The axis along which to index.
             (default: :obj:`-1`)
-        dim_size (int, optional): If :attr:`out` is not given, automatically
-            create output with size :attr:`dim_size` at dimension :attr:`dim`.
-            If :attr:`dim_size` is not given, a minimal sized output tensor is
-            returned. (default: :obj:`None`)
-        fill_value (int, optional): If :attr:`out` is not given, automatically
-            fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
-            the output tensor is filled with the smallest possible value of
-            :obj:`src.dtype`. (default: :obj:`None`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)

     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('log_softmax can be computed only over '
-                         'tensors with floating point data types.')
+        raise ValueError('`scatter_softmax` can only be computed over tensors '
+                         'with floating point data types.')

-    max_value_per_index, _ = scatter_max(src, index, dim=dim,
-                                         dim_size=dim_size)
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
     max_per_src_element = max_value_per_index.gather(dim, index)

     recentered_scores = src - max_per_src_element
+    recentered_scores_exp = recentered_scores.exp()

-    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
-                                dim=dim, dim_size=dim_size)
-    log_normalizing_constants = sum_per_index.log().gather(dim, index)
+    sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
+    normalizing_constants = (sum_per_index + eps).gather(dim, index)

-    return recentered_scores - log_normalizing_constants
+    return recentered_scores_exp / normalizing_constants


-def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
+def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
     r"""
-    Numerical safe log-softmax of all values from
-    the :attr:`src` tensor into :attr:`out` at the
-    indices specified in the :attr:`index` tensor along a given axis
-    :attr:`dim`. If multiple indices reference the same location, their
-    **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
+    Log-softmax operation over all values in :attr:`src` tensor that share
+    indices specified in the :attr:`index` tensor along a given axis
+    :attr:`dim`.

     For one-dimensional tensors, the operation computes

     .. math::
-        \mathrm{out}_i = softmax(\mathrm{src}_i) =
-        \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j (\mathrm{src}_j)}
+        \mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i =
+        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
+        \right)

-    where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
+    where :math:`\sum_j` is over :math:`j` such that
     :math:`\mathrm{index}_j = i`.

-    Compute a numerically safe softmax operation
-    from the :attr:`src` tensor into :attr:`out` at the indices
-    specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
-    each value in :attr:`src`, its output index is specified by its index in
-    :attr:`input` for dimensions outside of :attr:`dim` and by the
-    corresponding value in :attr:`index` for dimension :attr:`dim`.
-
     Args:
         src (Tensor): The source tensor.
         index (LongTensor): The indices of elements to scatter.
         dim (int, optional): The axis along which to index.
             (default: :obj:`-1`)
-        dim_size (int, optional): If :attr:`out` is not given, automatically
-            create output with size :attr:`dim_size` at dimension :attr:`dim`.
-            If :attr:`dim_size` is not given, a minimal sized output tensor is
-            returned. (default: :obj:`None`)
-        fill_value (int, optional): If :attr:`out` is not given, automatically
-            fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
-            the output tensor is filled with the smallest possible value of
-            :obj:`src.dtype`. (default: :obj:`None`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)

     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('softmax can be computed only over '
+        raise ValueError('`scatter_log_softmax` can only be computed over '
                          'tensors with floating point data types.')

-    max_value_per_index, _ = scatter_max(src, index, dim=dim,
-                                         dim_size=dim_size)
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
     max_per_src_element = max_value_per_index.gather(dim, index)

     recentered_scores = src - max_per_src_element

-    exped_recentered_scores = recentered_scores.exp()
-    sum_per_index = scatter_add(src=exped_recentered_scores, index=index,
-                                dim=dim, dim_size=dim_size)
-    normalizing_constant = (sum_per_index + epsilon).gather(dim, index)
-
-    return exped_recentered_scores / normalizing_constant
+    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
+                                dim=dim)
+    normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)
+
+    return recentered_scores - normalizing_constants
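Both functions rely on the same max-shift trick: subtracting a per-group reference value before exponentiating keeps exp() bounded by 1, without changing the result. The following sketch expands the steps of the new scatter_softmax on a concrete input (illustrative only; it mirrors, rather than replaces, the implementation above):

import torch
from torch_scatter import scatter_add, scatter_max

src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])

# 1. Per-group maximum (floored at 0, since the output buffer is
#    initialized with fill_value=0), broadcast back to each element.
max_per_index, _ = scatter_max(src, index, dim=-1, fill_value=0)
max_per_src = max_per_index.gather(-1, index)

# 2. Shift by that reference so the exponentials cannot overflow.
shifted_exp = (src - max_per_src).exp()

# 3. Per-group normalizer, broadcast back and divided out.
denom = scatter_add(shifted_exp, index, dim=-1).gather(-1, index)
out = shifted_exp / denom  # agrees with scatter_softmax(src, index)
                           # up to the eps=1e-12 it adds to the denominator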
torch_scatter/logsumexp.py

@@ -4,26 +4,22 @@ from . import scatter_add, scatter_max

 def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
-                      fill_value=None, epsilon=1e-16):
+                      fill_value=None, eps=1e-12):
     r"""
-    Numerically safe logsumexp of all values from
-    the :attr:`src` tensor into :attr:`out` at the
-    indices specified in the :attr:`index` tensor along a given axis
-    :attr:`dim`. If multiple indices reference the same location, their
-    **contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`).
+    Fills :attr:`out` with the log of summed exponentials of all values
+    from the :attr:`src` tensor at the indices specified in the :attr:`index`
+    tensor along a given axis :attr:`dim`.
+    If multiple indices reference the same location, their
+    **exponential contributions add**
+    (`cf.` :meth:`~torch_scatter.scatter_add`).

     For one-dimensional tensors, the operation computes

     .. math::
-        \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + \sum_j
-        \exp(\mathrm{src}_j) \right)
+        \mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
+        \exp(\mathrm{src}_j) \right)

-    Compute a numerically safe logsumexp operation
-    from the :attr:`src` tensor into :attr:`out` at the indices
-    specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
-    each value in :attr:`src`, its output index is specified by its index in
-    :attr:`input` for dimensions outside of :attr:`dim` and by the
-    corresponding value in :attr:`index` for dimension :attr:`dim`.
+    where :math:`\sum_j` is over :math:`j` such that
+    :math:`\mathrm{index}_j = i`.

     Args:
         src (Tensor): The source tensor.
@@ -36,35 +32,23 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
             If :attr:`dim_size` is not given, a minimal sized output tensor is
             returned. (default: :obj:`None`)
         fill_value (int, optional): If :attr:`out` is not given, automatically
-            fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
-            the output tensor is filled with the smallest possible value of
-            :obj:`src.dtype`. (default: :obj:`None`)
+            fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)

     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('logsumexp can only be computed over '
+        raise ValueError('`scatter_logsumexp` can only be computed over '
                          'tensors with floating point data types.')

     if fill_value is None:
         fill_value = torch.finfo(src.dtype).min
     dim_size = out.shape[dim] if dim_size is None and out is not None \
         else dim_size

-    max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out,
-                                         dim_size=dim_size,
-                                         fill_value=fill_value)
+    max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
+                                         fill_value)
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element

-    out = (out - max_per_src_element).exp() if out is not None else None
-    sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
-                                dim_size, fill_value=0)
+    sum_per_index = scatter_add(
+        src=recentered_scores.exp(),
+        index=index,
+        dim=dim,
+        out=(out - max_per_src_element).exp() if out is not None else None,
+        dim_size=dim_size,
+        fill_value=0,
+    )

-    return torch.log(sum_per_index + epsilon) + max_value_per_index
+    return torch.log(sum_per_index + eps) + max_value_per_index
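As with the softmax variants, the per-group result agrees with applying torch.logsumexp groupwise (a minimal sketch, using a subset of the grouping from the updated test):

import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0.0, 0.5, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])

out = scatter_logsumexp(src, index)  # one value per group

# Group 0 gathers entries 0.5 and 0.5; group 1 gathers 0, -2.1 and 3.2.
ref0 = torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1)
ref1 = torch.logsumexp(torch.tensor([0.0, -2.1, 3.2]), dim=-1)
assert torch.allclose(out, torch.stack([ref0, ref1]), atol=1e-6)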