import torch
from torch_scatter import scatter_log_softmax, scatter_softmax


def test_softmax():
    """Check scatter_softmax against torch.softmax computed per group.

    ``index`` partitions ``src`` into these groups:

    * 0 -> [0.2, 0.2]
    * 1 -> [0, -2.1, 3.2]
    * 2 -> [7]
    * 4 -> [-1, -inf]

    Index 3 is intentionally unused, and group 4 contains ``-inf``, so the
    test also covers a gap in the index range and an infinite input.
    """
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_softmax(src, index)

    # Reference values: a plain softmax over each group's entries.
    out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1)
    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
    # An integer literal would yield an integer tensor, so force float here.
    out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1)
    out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1)

    # Put the per-group results back into the original order of ``src``.
    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)


# Same inputs and grouping as test_softmax, checked in log space.
def test_log_softmax():
    """Check scatter_log_softmax against torch.log_softmax per group.

    ``index`` partitions ``src`` into these groups:

    * 0 -> [0.2, 0.2]
    * 1 -> [0, -2.1, 3.2]
    * 2 -> [7]
    * 4 -> [-1, -inf]

    Index 3 is intentionally unused, and group 4 contains ``-inf``, so the
    test also covers a gap in the index range and an infinite input.
    """
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_log_softmax(src, index)

    # Reference values: a plain log-softmax over each group's entries.
    out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1)
    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
    # An integer literal would yield an integer tensor, so force float here.
    out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1)
    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1)

    # Put the per-group results back into the original order of ``src``.
    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)