import torch

from torch_scatter import scatter_add, scatter_max
from torch_scatter.utils.gen import broadcast


def scatter_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Softmax operation over all values in the :attr:`src` tensor that share
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
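
    Example (an illustrative sketch; the values are made up rather than
    taken from the upstream documentation)::

        src = torch.tensor([0.5, 1.0, 2.0, 3.0])
        index = torch.tensor([0, 0, 1, 1])
        out = scatter_softmax(src, index)
        # out[:2] and out[2:] each sum to approximately 1, since the two
        # groups defined by `index` are normalized independently.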
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    # Broadcast `index` to the same shape as `src` along `dim`.
    src, index = broadcast(src, index, dim)

    # Subtract each group's maximum before exponentiating: the standard
    # log-sum-exp recentering trick that keeps `exp` from overflowing.
    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp()

    # Sum the exponentials within each group, map each group's sum back to
    # its source elements, and normalize; `eps` avoids division by zero.
    sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
    normalizing_constants = (sum_per_index + eps).gather(dim, index)

    return recentered_scores_exp / normalizing_constants


def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Log-softmax operation over all values in the :attr:`src` tensor that
    share indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{log\_softmax}(\mathrm{src})}_i =
        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
        \right)

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
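
    Example (an illustrative sketch; the values are made up rather than
    taken from the upstream documentation)::

        src = torch.tensor([0.5, 1.0, 2.0, 3.0])
        index = torch.tensor([0, 0, 1, 1])
        out = scatter_log_softmax(src, index)
        # matches scatter_softmax(src, index).log() up to the `eps` slack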
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    # Broadcast `index` to the same shape as `src` along `dim`.
    src, index = broadcast(src, index, dim)

    # Same max-recentering trick as in `scatter_softmax`: subtracting each
    # group's maximum keeps the exponentials from overflowing.
    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element

    # log-softmax = recentered scores minus the log of each group's sum of
    # exponentials; `eps` avoids taking the log of zero.
    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
                                dim=dim)
    normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)

    return recentered_scores - normalizing_constants
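

if __name__ == '__main__':
    # A small self-check sketch, not part of the original module: thanks to
    # the max-recentering above, the grouped softmax stays finite even when
    # the naive exponentials would overflow.
    src = torch.tensor([1000.0, 1001.0])
    index = torch.tensor([0, 0])
    print(scatter_softmax(src, index))  # tensor([0.2689, 0.7311])
    print(torch.exp(src) / torch.exp(src).sum())  # nan: exp overflows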