Commit 7b14c671 authored by Miltos Allamanis's avatar Miltos Allamanis

Bug fixes, testing and other minor edits.

* `log_softmax` is now implemented stand-alone to save one operation (and fix a bug).
* `softmax` is implemented in a similar stand-alone way.
* Address some PR comments.
parent 0ef92602
from itertools import product
import torch
import pytest
from torch_scatter import scatter_max, scatter_logsumexp
from .utils import devices, tensor
SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
def test_logsumexp(dtype, device):
src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_logsumexp(src, index)
idx0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1).tolist()
idx1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
idx2 = 7 # Single element
idx3 = torch.finfo(dtype).min # Empty index, returns fill value
idx4 = -1 # -inf is the identity element of logsumexp
assert out.tolist() == [idx0, idx1, idx2, idx3, idx4]
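# Reference note (illustrative, not part of the test file): the expected values above
# follow from the standard max-shift identity
#   logsumexp(x) = max(x) + log(sum(exp(x - max(x)))),
# which is also what makes the scatter implementation numerically safe. A quick sketch
# checking the idx1 group [0, -2.1, 3.2] by hand with plain torch:
import torch
x = torch.tensor([0.0, -2.1, 3.2], dtype=torch.float64)
m = x.max()
stable = m + torch.log(torch.exp(x - m).sum())
assert torch.isclose(stable, torch.logsumexp(x, dim=-1))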
from itertools import product
import numpy as np
import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax
from .utils import devices, tensor
SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64}
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
def test_log_softmax(dtype, device):
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_log_softmax(src, index)
# Expected results per index
idx0 = [np.log(0.5), np.log(0.5)]
idx1 = torch.log_softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
idx2 = 0.0 # Single element, has logprob=0
# index=3 is empty. Should not matter.
idx4 = [0.0, float('-inf')] # log_softmax with -inf preserves the -inf
np.testing.assert_allclose(
out.tolist(),
[idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
rtol=1e-05, atol=1e-10
)
@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices))
def test_softmax(dtype, device):
src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
out = scatter_softmax(src, index)
# Expected results per index
idx0 = [0.5, 0.5]
idx1 = torch.softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist()
idx2 = 1 # Single element, has prob=1
# index=3 is empty. Should not matter.
idx4 = [1.0, 0.0] # softmax with -inf yields zero probability
np.testing.assert_allclose(
out.tolist(),
[idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]],
rtol=1e-05, atol=1e-10
)
\ No newline at end of file
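# Illustrative check (not part of the test file): the idx4 expectations above follow
# from plain torch semantics for -inf scores, e.g. softmax([-1, -inf]) = [1, 0] and
# log_softmax([-1, -inf]) = [0, -inf], which the scatter variants reproduce per index group.
import torch
scores = torch.tensor([-1.0, float('-inf')])
print(torch.softmax(scores, dim=-1))      # tensor([1., 0.])
print(torch.log_softmax(scores, dim=-1))  # tensor([0., -inf])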
import torch
from torch_scatter.logsumexp import _scatter_logsumexp
from torch_scatter import scatter_add, scatter_max
def scatter_log_softmax(src, index, dim=-1, dim_size=None):
r"""
@@ -12,7 +12,8 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = softmax(\mathrm{src}_i) = \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)
\mathrm{out}_i = \mathrm{log\_softmax}(\mathrm{src}_i) =
\mathrm{src}_i - \mathrm{logsumexp}_j(\mathrm{src}_j)
where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
@@ -40,11 +41,26 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
:rtype: :class:`Tensor`
"""
per_index_logsumexp, recentered_src = _scatter_logsumexp(src, index, dim=dim, dim_size=dim_size)
return recentered_src - per_index_logsumexp.gather(dim, index)
if not torch.is_floating_point(src):
raise ValueError('log_softmax can be computed only over tensors with floating point data types.')
max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
sum_per_index = scatter_add(
src=recentered_scores.exp(),
index=index,
dim=dim,
dim_size=dim_size
)
log_normalizing_constants = sum_per_index.log().gather(dim, index)
return recentered_scores - log_normalizing_constants
def scatter_softmax(src, index, dim=-1, dim_size=None):
def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
r"""
Numerically safe softmax of all values from the :attr:`src` tensor at the
indices specified in the :attr:`index` tensor along a given axis
@@ -54,7 +70,8 @@ def scatter_softmax(src, index, dim=-1, dim_size=None):
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = softmax(\mathrm{src}_i) = \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)}
\mathrm{out}_i = \mathrm{softmax}(\mathrm{src}_i) =
\frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
where the summation :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
@@ -82,4 +99,20 @@ def scatter_softmax(src, index, dim=-1, dim_size=None):
:rtype: :class:`Tensor`
"""
return scatter_log_softmax(src, index, dim, dim_size).exp()
if not torch.is_floating_point(src):
raise ValueError('softmax can be computed only over tensors with floating point data types.')
max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
exped_recentered_scores = recentered_scores.exp()
sum_per_index = scatter_add(
src=exped_recentered_scores,
index=index,
dim=dim,
dim_size=dim_size
)
normalizing_constant = (sum_per_index + epsilon).gather(dim, index)
return exped_recentered_scores / normalizing_constant
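# Usage sketch (illustrative, not part of this file): normalizing scores within index
# groups with the composite ops, assuming the import path used in the tests above.
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

src = torch.tensor([0.25, 0.0, 0.25, -2.1, 3.2, 7.0])
index = torch.tensor([0, 1, 0, 1, 1, 2])
probs = scatter_softmax(src, index)          # probabilities sum to 1 within each index group
log_probs = scatter_log_softmax(src, index)  # the same normalization in log space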
@@ -3,30 +3,6 @@ import torch
from . import scatter_add, scatter_max
def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
if not torch.is_floating_point(src):
raise ValueError('logsumexp can be computed over tensors floating point data types.')
if fill_value is None:
fill_value = torch.finfo(src.dtype).min
dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size
max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
sum_per_index = scatter_add(
src=recentered_scores.exp(),
index=index,
dim=dim,
out=(src - max_per_src_element).exp() if out is not None else None,
dim_size=dim_size,
fill_value=fill_value,
)
return torch.log(sum_per_index + epsilon) + max_value_per_index, recentered_scores
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16):
r"""
Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the
@@ -63,4 +39,24 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No
:rtype: :class:`Tensor`
"""
return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value, epsilon=epsilon)[0]
if not torch.is_floating_point(src):
raise ValueError('logsumexp can only be computed over tensors with floating point data types.')
if fill_value is None:
fill_value = torch.finfo(src.dtype).min
dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size
max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
sum_per_index = scatter_add(
src=recentered_scores.exp(),
index=index,
dim=dim,
out=(out - max_per_src_element).exp() if out is not None else None,
dim_size=dim_size,
fill_value=0,
)
return torch.log(sum_per_index + epsilon) + max_value_per_index
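# Usage sketch (illustrative, not part of this file): per-index logsumexp, using the
# import path from the tests above.
import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0.0, 0.5, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])
out = scatter_logsumexp(src, index)  # out[i] is the logsumexp of src values with index == i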