softmax.py 1.64 KB
Newer Older
1
2
import torch

rusty1s's avatar
rusty1s committed
3
from torch_scatter import scatter_sum, scatter_max
4
from torch_scatter.utils import broadcast
5

6

rusty1s's avatar
rusty1s committed
7
8
9
@torch.jit.script
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                    eps: float = 1e-12) -> torch.Tensor:
    """Compute a softmax over the entries of ``src`` grouped by ``index``.

    Entries of ``src`` that share the same index value along ``dim`` are
    normalized together: each entry is shifted by its group's maximum,
    exponentiated, and divided by the group's sum of exponentials.

    Args:
        src: Floating point tensor of scores.
        index: Group assignment for the entries of ``src``. It is passed
            through ``broadcast(index, src, dim)`` before being used.
        dim: Dimension along which the groups are formed.
        eps: Constant added to every group's sum of exponentials before
            the division.

    Returns:
        A newly allocated tensor with the normalized values; ``src`` is not
        modified.

    Raises:
        ValueError: If ``src`` is not a floating point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    index = broadcast(index, src, dim)

    # Shift every entry by the maximum of its group so that the largest
    # exponent in each group is exp(0) and cannot overflow.
    max_value_per_index = scatter_max(src, index, dim=dim)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)

    # The subtraction allocates a fresh tensor, so exponentiating it in place
    # does not touch the caller's `src`.
    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp_()

    sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
    normalizing_constants = sum_per_index.add_(eps).gather(dim, index)

    # Divide out of place: autograd keeps the output of `exp_` for its
    # backward pass, and overwriting that tensor with `div_` makes
    # `backward()` fail with an "modified by an inplace operation" error.
    return recentered_scores_exp.div(normalizing_constants)
26

27

rusty1s's avatar
rusty1s committed
28
29
30
@torch.jit.script
def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                        eps: float = 1e-12) -> torch.Tensor:
    """Compute a log-softmax over the entries of ``src`` grouped by ``index``.

    Entries of ``src`` that share the same index value along ``dim`` form one
    group. Each entry is shifted by its group's maximum, and the logarithm of
    the group's (``eps``-offset) sum of exponentials is subtracted from it.

    Args:
        src: Floating point tensor of scores.
        index: Group assignment for the entries of ``src``. It is passed
            through ``broadcast(index, src, dim)`` before being used.
        dim: Dimension along which the groups are formed.
        eps: Constant added to every group's sum of exponentials before the
            logarithm is taken.

    Returns:
        A newly allocated tensor with the log-normalized values; ``src`` is
        not modified.

    Raises:
        ValueError: If ``src`` is not a floating point tensor.
    """
    if not src.is_floating_point():
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    # Only the maxima are needed; the second element of the result is unused.
    group_max, _unused = scatter_max(src, index, dim=dim)

    # Fresh tensor holding the scores shifted by their group's maximum; it is
    # reused as the output buffer at the end.
    shifted = src - torch.gather(group_max, dim, index)

    # log(sum(exp(shifted)) + eps) per group, computed in place on the
    # freshly produced sum tensor.
    group_log_norm = scatter_sum(torch.exp(shifted), index, dim)
    group_log_norm.add_(eps)
    group_log_norm.log_()

    log_norm_per_element = torch.gather(group_log_norm, dim, index)
    return shifted.sub_(log_norm_per_element)