Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
d63eb9c9
Commit
d63eb9c9
authored
Nov 05, 2019
by
Miltos Allamanis
Browse files
Remaining flake8 formatting errors
parent
0c127881
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
9 deletions
+22
-9
test/test_logsumexp.py
test/test_logsumexp.py
+8
-3
test/test_softmax.py
test/test_softmax.py
+14
-6
No files found.
test/test_logsumexp.py
View file @
d63eb9c9
...
@@ -9,15 +9,20 @@ from .utils import devices, tensor
...
@@ -9,15 +9,20 @@ from .utils import devices, tensor
SUPPORTED_FLOAT_DTYPES
=
{
torch
.
float32
,
torch
.
float64
}
SUPPORTED_FLOAT_DTYPES
=
{
torch
.
float32
,
torch
.
float64
}
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
SUPPORTED_FLOAT_DTYPES
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
SUPPORTED_FLOAT_DTYPES
,
devices
))
def
test_logsumexp
(
dtype
,
device
):
def
test_logsumexp
(
dtype
,
device
):
src
=
tensor
([
0.5
,
0
,
0.5
,
-
2.1
,
3.2
,
7
,
-
1
,
float
(
'-inf'
)],
dtype
,
device
)
src
=
tensor
([
0.5
,
0
,
0.5
,
-
2.1
,
3.2
,
7
,
-
1
,
float
(
'-inf'
)],
dtype
,
device
)
index
=
tensor
([
0
,
1
,
0
,
1
,
1
,
2
,
4
,
4
],
torch
.
long
,
device
)
index
=
tensor
([
0
,
1
,
0
,
1
,
1
,
2
,
4
,
4
],
torch
.
long
,
device
)
out
=
scatter_logsumexp
(
src
,
index
)
out
=
scatter_logsumexp
(
src
,
index
)
idx0
=
torch
.
logsumexp
(
torch
.
tensor
([
0.5
,
0.5
],
dtype
=
dtype
),
dim
=-
1
).
tolist
()
idx0
=
torch
.
logsumexp
(
idx1
=
torch
.
logsumexp
(
torch
.
tensor
([
0
,
-
2.1
,
3.2
],
dtype
=
dtype
),
dim
=-
1
).
tolist
()
torch
.
tensor
([
0.5
,
0.5
],
dtype
=
dtype
),
dim
=-
1
).
tolist
()
idx1
=
torch
.
logsumexp
(
torch
.
tensor
([
0
,
-
2.1
,
3.2
],
dtype
=
dtype
),
dim
=-
1
).
tolist
()
idx2
=
7
# Single element
idx2
=
7
# Single element
idx3
=
torch
.
finfo
(
dtype
).
min
# Empty index, returns yield value
idx3
=
torch
.
finfo
(
dtype
).
min
# Empty index, returns yield value
idx4
=
-
1
# logsumexp with -inf is the identity
idx4
=
-
1
# logsumexp with -inf is the identity
...
...
test/test_softmax.py
View file @
d63eb9c9
...
@@ -10,16 +10,20 @@ from .utils import devices, tensor
...
@@ -10,16 +10,20 @@ from .utils import devices, tensor
SUPPORTED_FLOAT_DTYPES
=
{
torch
.
float32
,
torch
.
float64
}
SUPPORTED_FLOAT_DTYPES
=
{
torch
.
float32
,
torch
.
float64
}
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
SUPPORTED_FLOAT_DTYPES
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
SUPPORTED_FLOAT_DTYPES
,
devices
))
def
test_log_softmax
(
dtype
,
device
):
def
test_log_softmax
(
dtype
,
device
):
src
=
tensor
([
0.25
,
0
,
0.25
,
-
2.1
,
3.2
,
7
,
-
1
,
float
(
'-inf'
)],
dtype
,
device
)
src
=
tensor
([
0.25
,
0
,
0.25
,
-
2.1
,
3.2
,
7
,
-
1
,
float
(
'-inf'
)],
dtype
,
device
)
index
=
tensor
([
0
,
1
,
0
,
1
,
1
,
2
,
4
,
4
],
torch
.
long
,
device
)
index
=
tensor
([
0
,
1
,
0
,
1
,
1
,
2
,
4
,
4
],
torch
.
long
,
device
)
out
=
scatter_log_softmax
(
src
,
index
)
out
=
scatter_log_softmax
(
src
,
index
)
# Expected results per index
# Expected results per index
idx0
=
[
np
.
log
(
0.5
),
np
.
log
(
0.5
)]
idx0
=
[
np
.
log
(
0.5
),
np
.
log
(
0.5
)]
idx1
=
torch
.
log_softmax
(
torch
.
tensor
([
0.0
,
-
2.1
,
3.2
],
dtype
=
dtype
),
dim
=-
1
).
tolist
()
idx1
=
torch
.
log_softmax
(
torch
.
tensor
([
0.0
,
-
2.1
,
3.2
],
dtype
=
dtype
),
dim
=-
1
).
tolist
()
idx2
=
0.0
# Single element, has logprob=0
idx2
=
0.0
# Single element, has logprob=0
# index=3 is empty. Should not matter.
# index=3 is empty. Should not matter.
idx4
=
[
0.0
,
float
(
'-inf'
)]
# log_softmax with -inf preserves the -inf
idx4
=
[
0.0
,
float
(
'-inf'
)]
# log_softmax with -inf preserves the -inf
...
@@ -31,16 +35,20 @@ def test_log_softmax(dtype, device):
...
@@ -31,16 +35,20 @@ def test_log_softmax(dtype, device):
)
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
SUPPORTED_FLOAT_DTYPES
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
SUPPORTED_FLOAT_DTYPES
,
devices
))
def
test_softmax
(
dtype
,
device
):
def
test_softmax
(
dtype
,
device
):
src
=
tensor
([
0.25
,
0
,
0.25
,
-
2.1
,
3.2
,
7
,
-
1
,
float
(
'-inf'
)],
dtype
,
device
)
src
=
tensor
([
0.25
,
0
,
0.25
,
-
2.1
,
3.2
,
7
,
-
1
,
float
(
'-inf'
)],
dtype
,
device
)
index
=
tensor
([
0
,
1
,
0
,
1
,
1
,
2
,
4
,
4
],
torch
.
long
,
device
)
index
=
tensor
([
0
,
1
,
0
,
1
,
1
,
2
,
4
,
4
],
torch
.
long
,
device
)
out
=
scatter_softmax
(
src
,
index
)
out
=
scatter_softmax
(
src
,
index
)
# Expected results per index
# Expected results per index
idx0
=
[
0.5
,
0.5
]
idx0
=
[
0.5
,
0.5
]
idx1
=
torch
.
softmax
(
torch
.
tensor
([
0.0
,
-
2.1
,
3.2
],
dtype
=
dtype
),
dim
=-
1
).
tolist
()
idx1
=
torch
.
softmax
(
torch
.
tensor
([
0.0
,
-
2.1
,
3.2
],
dtype
=
dtype
),
dim
=-
1
).
tolist
()
idx2
=
1
# Single element, has prob=1
idx2
=
1
# Single element, has prob=1
# index=3 is empty. Should not matter.
# index=3 is empty. Should not matter.
idx4
=
[
1.0
,
0.0
]
# softmax with -inf yields zero probability
idx4
=
[
1.0
,
0.0
]
# softmax with -inf yields zero probability
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment