test_softmax.py 1.61 KB
Newer Older
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch_scatter import scatter_log_softmax, scatter_softmax
3
4


rusty1s's avatar
rusty1s committed
5
6
def test_softmax():
    """Check ``scatter_softmax`` values, gradients and TorchScript output.

    ``index`` assigns the entries of ``src`` to groups 0, 1, 2 and 4 (group 3
    receives no entries), and group 4 contains a ``-inf`` score. The reference
    result is ``torch.softmax`` applied to each group separately.
    """
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
    src.requires_grad_()
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_softmax(src, index)

    # Per-group reference built on an independent leaf tensor holding the
    # same values, so autograd also yields a reference gradient to compare.
    ref_src = src.detach().clone().requires_grad_()
    out0 = torch.softmax(ref_src[[0, 2]], dim=-1)
    out1 = torch.softmax(ref_src[[1, 3, 4]], dim=-1)
    out2 = torch.softmax(ref_src[[5]], dim=-1)
    out4 = torch.softmax(ref_src[[6, 7]], dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)

    # A fixed upstream gradient keeps the gradient comparison deterministic.
    grad_out = torch.tensor([0.5, -1.0, 2.0, 0.3, -0.7, 1.5, -0.2, 0.9])
    out.backward(grad_out)
    expected.backward(grad_out)

    assert src.grad is not None
    # The -inf entry must not poison the backward pass with NaN/inf.
    assert torch.isfinite(src.grad).all()
    # Small absolute tolerance: some gradients are mathematically zero and
    # the two implementations may round differently around that value.
    assert torch.allclose(src.grad, ref_src.grad, atol=1e-6)

    jit = torch.jit.script(scatter_softmax)
    assert jit(src, index).tolist() == out.tolist()

28

rusty1s's avatar
rusty1s committed
29
30
def test_log_softmax():
    """Check ``scatter_log_softmax`` values, gradients and TorchScript output.

    ``index`` assigns the entries of ``src`` to groups 0, 1, 2 and 4 (group 3
    receives no entries), and group 4 contains a ``-inf`` score. The reference
    result is ``torch.log_softmax`` applied to each group separately.
    """
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
    src.requires_grad_()
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_log_softmax(src, index)

    # Per-group reference built on an independent leaf tensor holding the
    # same values, so autograd also yields a reference gradient to compare.
    ref_src = src.detach().clone().requires_grad_()
    out0 = torch.log_softmax(ref_src[[0, 2]], dim=-1)
    out1 = torch.log_softmax(ref_src[[1, 3, 4]], dim=-1)
    out2 = torch.log_softmax(ref_src[[5]], dim=-1)
    out4 = torch.log_softmax(ref_src[[6, 7]], dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    # The -inf input yields -inf on both sides; torch.allclose treats equal
    # infinities as close.
    assert torch.allclose(out, expected)

    # A fixed upstream gradient keeps the gradient comparison deterministic.
    grad_out = torch.tensor([0.5, -1.0, 2.0, 0.3, -0.7, 1.5, -0.2, 0.9])
    out.backward(grad_out)
    expected.backward(grad_out)

    assert src.grad is not None
    # The -inf entry must not poison the backward pass with NaN/inf.
    assert torch.isfinite(src.grad).all()
    # Small absolute tolerance: some gradients are mathematically zero and
    # the two implementations may round differently around that value.
    assert torch.allclose(src.grad, ref_src.grad, atol=1e-6)

    jit = torch.jit.script(scatter_log_softmax)
    assert jit(src, index).tolist() == out.tolist()