Commit 9c7af8df authored by Miltos Allamanis's avatar Miltos Allamanis
Browse files

Move epsilon to an argument.

parent dd50d35f
...@@ -2,9 +2,8 @@ import torch ...@@ -2,9 +2,8 @@ import torch
from . import scatter_add, scatter_max from . import scatter_add, scatter_max
EPSILON = 1e-16
def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None): def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
if not torch.is_floating_point(src): if not torch.is_floating_point(src):
raise ValueError('logsumexp can be computed over tensors floating point data types.') raise ValueError('logsumexp can be computed over tensors floating point data types.')
...@@ -25,9 +24,10 @@ def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=N ...@@ -25,9 +24,10 @@ def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=N
dim_size=dim_size, dim_size=dim_size,
fill_value=fill_value, fill_value=fill_value,
) )
return torch.log(sum_per_index + EPSILON) + max_value_per_index, recentered_scores return torch.log(sum_per_index + epsilon) + max_value_per_index, recentered_scores
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
r""" r"""
Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis indices specified in the :attr:`index` tensor along a given axis
...@@ -63,4 +63,4 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No ...@@ -63,4 +63,4 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
""" """
return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value)[0] return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value, epsilon=epsilon)[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment