OpenDAS / torch-scatter / Commits / 0c127881

Commit 0c127881, authored Nov 05, 2019 by Miltos Allamanis
Parent: 7b14c671

    Address most flake8, pycodestyle errors.
Showing 4 changed files with 35 additions and 16 deletions (+35 -16)
test/test_logsumexp.py              +2  -1
test/test_softmax.py                +2  -1
torch_scatter/composite/softmax.py  +17 -8
torch_scatter/logsumexp.py          +14 -6
test/test_logsumexp.py  (view file @ 0c127881)

@@ -2,12 +2,13 @@ from itertools import product
 import torch
 import pytest
-from torch_scatter import scatter_max, scatter_logsumexp
+from torch_scatter import scatter_logsumexp
 from .utils import devices, tensor

 SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}


 @pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
 def test_logsumexp(dtype, device):
     src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
...
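For orientation, a minimal sketch of exercising scatter_logsumexp with the inputs visible in the hunk above. The index tensor is an assumption (it is not shown in this diff; the grouping mirrors the idx* pattern in test_softmax below), and the reference computation is illustrative rather than the test's actual assertion.

import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')])
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])  # assumed grouping, not from the diff

out = scatter_logsumexp(src, index)

# Plain-PyTorch reference for one group: logsumexp over the
# elements scattered to index 1.
ref_group1 = torch.logsumexp(src[index == 1], dim=0)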
test/test_softmax.py  (view file @ 0c127881)

@@ -9,6 +9,7 @@ from .utils import devices, tensor
 SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}


 @pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
 def test_log_softmax(dtype, device):
     src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
...
@@ -48,4 +49,4 @@ def test_softmax(dtype, device):
         out.tolist(),
         [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
-        rtol=1e-05, atol=1e-10)
\ No newline at end of file
+        rtol=1e-05, atol=1e-10)
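The only change in the second hunk is the added trailing newline (pycodestyle W292); the assertion itself is untouched. Below is a self-contained sketch of what that comparison plausibly looks like end to end. The src and index values and the idx* helpers are reconstructed assumptions, not code from this diff, and it is assumed scatter_softmax is exported at the package top level, as scatter_logsumexp is above.

import numpy as np
import torch
from torch_scatter import scatter_softmax

src = torch.tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')])
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
out = scatter_softmax(src, index)  # same shape as src: softmax within each group

# Per-group reference softmaxes, matching the idx* names in the hunk.
idx0 = torch.softmax(src[index == 0], dim=0).tolist()
idx1 = torch.softmax(src[index == 1], dim=0).tolist()
idx2 = torch.softmax(src[index == 2], dim=0).tolist()[0]
idx4 = torch.softmax(src[index == 4], dim=0).tolist()

np.testing.assert_allclose(
    out.tolist(),
    [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
    rtol=1e-05, atol=1e-10)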
torch_scatter/composite/softmax.py  (view file @ 0c127881)

@@ -2,9 +2,11 @@ import torch
 from torch_scatter import scatter_add, scatter_max


 def scatter_log_softmax(src, index, dim=-1, dim_size=None):
     r"""
-    Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
+    Numerical safe log-softmax of all values from
+    the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
     :attr:`dim`.If multiple indices reference the same location, their
     **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
...
@@ -12,7 +14,7 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
     For one-dimensional tensors, the operation computes

     .. math::
         \mathrm{out}_i = softmax(\mathrm{src}_i) =
         \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)

     where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
...
@@ -42,9 +44,12 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('log_softmax can be computed only over tensors with floating point data types.')
+        raise ValueError('log_softmax can be computed only over '
+                         'tensors with floating point data types.')

     max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element
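The hunk cuts off before the function returns. For orientation, here is a sketch of how the recentering visible above completes into a full grouped log-softmax. This is the standard construction, not code taken from this commit; the name scatter_log_softmax_sketch is hypothetical.

import torch
from torch_scatter import scatter_add, scatter_max

def scatter_log_softmax_sketch(src, index, dim=-1, dim_size=None):
    # Subtract each group's max so exp() below cannot overflow; the
    # shift cancels exactly in the final log-softmax expression.
    max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
    recentered_scores = src - max_value_per_index.gather(dim, index)
    # Assumed continuation: per-group normalizer via scatter_add of
    # the shifted exponentials, gathered back to each source element.
    sum_per_index = scatter_add(recentered_scores.exp(), index,
                                dim=dim, dim_size=dim_size)
    log_normalizer = sum_per_index.log().gather(dim, index)
    return recentered_scores - log_normalizer

src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7.0])
index = torch.tensor([0, 1, 0, 1, 1, 2])
log_probs = scatter_log_softmax_sketch(src, index)  # same shape as src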
...
@@ -62,7 +67,8 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
 def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
     r"""
-    Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
+    Numerical safe log-softmax of all values from
+    the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
     :attr:`dim`. If multiple indices reference the same location, their
     **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
...
@@ -70,7 +76,7 @@ def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
     For one-dimensional tensors, the operation computes

     .. math::
         \mathrm{out}_i = softmax(\mathrm{src}_i) =
         \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)}

     where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
...
@@ -100,9 +106,12 @@ def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('softmax can be computed only over tensors with floating point data types.')
+        raise ValueError('softmax can be computed only over '
+                         'tensors with floating point data types.')

     max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element
...
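As with scatter_log_softmax, the diff stops at recentered_scores. A companion sketch for the softmax variant, assuming the epsilon parameter from the signature guards the division; again a hypothetical completion, not this commit's code.

from torch_scatter import scatter_add, scatter_max

def scatter_softmax_sketch(src, index, dim=-1, dim_size=None, epsilon=1e-16):
    # Same recentering trick as in the hunks above.
    max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
    recentered_scores = src - max_value_per_index.gather(dim, index)
    # Assumed continuation: normalize each exponential by its group's
    # sum; epsilon guards against a zero denominator in degenerate
    # groups (e.g. all elements -inf).
    exp_scores = recentered_scores.exp()
    sum_per_index = scatter_add(exp_scores, index, dim=dim, dim_size=dim_size)
    return exp_scores / (sum_per_index.gather(dim, index) + epsilon)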
torch_scatter/logsumexp.py  (view file @ 0c127881)

@@ -3,9 +3,11 @@ import torch
 from . import scatter_add, scatter_max


 def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
     r"""
-    Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
+    Numerically safe logsumexp of all values from
+    the :attr:`src` tensor into :attr:`out` at the
     indices specified in the :attr:`index` tensor along a given axis
     :attr:`dim`. If multiple indices reference the same location, their
     **contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`).
...
@@ -13,7 +15,8 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No
     For one-dimensional tensors, the operation computes

     .. math::
-        \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + \sum_j \exp(\mathrm{src}_j) \right)
+        \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i)
+        + \sum_j \exp(\mathrm{src}_j) \right)

     Compute a numerically safe logsumexp operation
     from the :attr:`src` tensor into :attr:`out` at the indices
...
@@ -40,13 +43,18 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No
     :rtype: :class:`Tensor`
     """
     if not torch.is_floating_point(src):
-        raise ValueError('logsumexp can be computed over tensors with floating point data types.')
+        raise ValueError('logsumexp can only be computed over '
+                         'tensors with floating point data types.')

     if fill_value is None:
         fill_value = torch.finfo(src.dtype).min
-    dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size
+    dim_size = out.shape[dim] \
+        if dim_size is None and out is not None else dim_size
     max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element
...
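Here too the visible diff ends at recentered_scores. A minimal sketch of the standard logsumexp recombination, under the simplifying assumption that out and fill_value are unused (the real function handles both); scatter_logsumexp_sketch is a hypothetical name.

from torch_scatter import scatter_add, scatter_max

def scatter_logsumexp_sketch(src, index, dim=-1, dim_size=None, epsilon=1e-16):
    # Recentering, as in the visible part of the hunk.
    max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
    recentered_scores = src - max_value_per_index.gather(dim, index)
    # Assumed continuation, using the identity
    # logsumexp(x) = max(x) + log(sum(exp(x - max(x)))).
    # epsilon keeps log() finite for output slots that receive no input.
    sum_per_index = scatter_add(recentered_scores.exp(), index,
                                dim=dim, dim_size=dim_size)
    return max_value_per_index + (sum_per_index + epsilon).log()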