Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
9c7af8df
"...text-generation-inference.git" did not exist on "2fe5e1b30e5bd490887e5b3ffa9fcaab8cbe1304"
Commit
9c7af8df
authored
Nov 04, 2019
by
Miltos Allamanis
Browse files
Move epsilon to an argument.
parent
dd50d35f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
torch_scatter/logsumexp.py
torch_scatter/logsumexp.py
+5
-5
No files found.
torch_scatter/logsumexp.py
View file @
9c7af8df
...
@@ -2,9 +2,8 @@ import torch
...
@@ -2,9 +2,8 @@ import torch
from
.
import
scatter_add
,
scatter_max
from
.
import
scatter_add
,
scatter_max
EPSILON
=
1e-16
def
_scatter_logsumexp
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
None
):
def
_scatter_logsumexp
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
None
,
epsilon
=
1e-16
):
if
not
torch
.
is_floating_point
(
src
):
if
not
torch
.
is_floating_point
(
src
):
raise
ValueError
(
'logsumexp can be computed over tensors floating point data types.'
)
raise
ValueError
(
'logsumexp can be computed over tensors floating point data types.'
)
...
@@ -25,9 +24,10 @@ def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=N
...
@@ -25,9 +24,10 @@ def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=N
dim_size
=
dim_size
,
dim_size
=
dim_size
,
fill_value
=
fill_value
,
fill_value
=
fill_value
,
)
)
return
torch
.
log
(
sum_per_index
+
EPSILON
)
+
max_value_per_index
,
recentered_scores
return
torch
.
log
(
sum_per_index
+
epsilon
)
+
max_value_per_index
,
recentered_scores
def
scatter_logsumexp
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
None
):
def
scatter_logsumexp
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
None
,
epsilon
=
1e-16
):
r
"""
r
"""
Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
indices specified in the :attr:`index` tensor along a given axis
...
@@ -63,4 +63,4 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No
...
@@ -63,4 +63,4 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No
:rtype: :class:`Tensor`
:rtype: :class:`Tensor`
"""
"""
return
_scatter_logsumexp
(
src
,
index
,
dim
,
out
,
dim_size
,
fill_value
)[
0
]
return
_scatter_logsumexp
(
src
,
index
,
dim
,
out
,
dim_size
,
fill_value
,
epsilon
=
epsilon
)[
0
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment