# test_softmax.py
import torch
from torch_scatter import scatter_log_softmax, scatter_softmax
def test_softmax():
    """Check ``scatter_softmax`` against ``torch.softmax`` applied per group.

    ``index`` assigns the entries of ``src`` to groups 0, 1, 2 and 4. Group 3
    has no entries. Group 4 contains ``-inf``, whose softmax weight must be 0.
    Both the forward output and the gradient w.r.t. ``src`` are compared with
    a dense per-group reference.
    """
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
    src.requires_grad_()
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_softmax(src, index)

    out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1)
    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
    out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1)
    out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)

    # Verify the backward pass, not just that it runs: compare with autograd
    # through torch.softmax on each group separately. A fixed upstream
    # gradient keeps the test deterministic.
    grad_out = torch.tensor([0.5, -1.0, 2.0, 0.3, -0.7, 1.5, -0.2, 0.8])
    out.backward(grad_out)

    expected_grad = torch.zeros_like(src)
    for group in index.unique():
        mask = index == group
        group_src = src.detach()[mask].requires_grad_()
        torch.softmax(group_src, dim=-1).backward(grad_out[mask])
        expected_grad[mask] = group_src.grad

    assert src.grad is not None
    # Small atol: the scatter implementation may route gradient through the
    # per-group max shift, which cancels analytically but not bit-exactly.
    assert torch.allclose(src.grad, expected_grad, atol=1e-6)


def test_log_softmax():
    """Check ``scatter_log_softmax`` against ``torch.log_softmax`` per group.

    ``index`` assigns the entries of ``src`` to groups 0, 1, 2 and 4. Group 3
    has no entries. Group 4 contains ``-inf``, whose log-probability must be
    ``-inf``. ``torch.allclose`` treats equal infinities as close. Both the
    forward output and the gradient w.r.t. ``src`` are compared with a dense
    per-group reference.
    """
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
    src.requires_grad_()
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_log_softmax(src, index)

    out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1)
    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
    out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1)
    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)

    # Verify the backward pass, not just that it runs: compare with autograd
    # through torch.log_softmax on each group separately. A fixed upstream
    # gradient keeps the test deterministic.
    grad_out = torch.tensor([0.5, -1.0, 2.0, 0.3, -0.7, 1.5, -0.2, 0.8])
    out.backward(grad_out)

    expected_grad = torch.zeros_like(src)
    for group in index.unique():
        mask = index == group
        group_src = src.detach()[mask].requires_grad_()
        torch.log_softmax(group_src, dim=-1).backward(grad_out[mask])
        expected_grad[mask] = group_src.grad

    assert src.grad is not None
    # Small atol: the scatter implementation may route gradient through the
    # per-group max shift, which cancels analytically but not bit-exactly.
    assert torch.allclose(src.grad, expected_grad, atol=1e-6)