import torch

from torch_scatter import scatter_add, scatter_max


def scatter_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Softmax operation over all values in the :attr:`src` tensor that share
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)
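
    Example (an illustrative sketch; input values are made up, and the
    printed output follows PyTorch's default rounding)::

        >>> src = torch.tensor([0.5, 0.5, 1.0, 2.0])
        >>> index = torch.tensor([0, 0, 1, 1])
        >>> scatter_softmax(src, index)
        tensor([0.5000, 0.5000, 0.2689, 0.7311])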

    :rtype: :class:`Tensor`
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    # Subtract each element's per-group maximum before exponentiating,
    # the standard log-sum-exp trick for numerical stability.
    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp()

    sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
    normalizing_constants = (sum_per_index + eps).gather(dim, index)

    return recentered_scores_exp / normalizing_constants


def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Log-softmax operation over all values in the :attr:`src` tensor that
    share indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{log\_softmax}(\mathrm{src})}_i =
        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
        \right)

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)
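
    Example (an illustrative sketch with the same inputs as above; printed
    output follows PyTorch's default rounding)::

        >>> src = torch.tensor([0.5, 0.5, 1.0, 2.0])
        >>> index = torch.tensor([0, 0, 1, 1])
        >>> scatter_log_softmax(src, index)
        tensor([-0.6931, -0.6931, -1.3133, -0.3133])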

    :rtype: :class:`Tensor`
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    # Recenter by the per-group maximum, as in `scatter_softmax` above,
    # so that the exponentials below cannot overflow.
    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element

    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
                                dim=dim)

    normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)

    return recentered_scores - normalizing_constants