test_softmax.py 2.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from itertools import product

import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

from test.utils import devices, tensor, grad_dtypes


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_softmax(dtype, device):
    """scatter_softmax must equal torch.softmax applied per index group.

    Groups (by ``index``): {0.2, 0.2} -> 0, {0, -2.1, 3.2} -> 1, {7} -> 2,
    {-1, -inf} -> 4.  Index 3 is intentionally absent (empty segment), and
    the ``-inf`` entry checks that a zero-probability element is handled.
    """
    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_softmax(src, index)

    # Reference values: plain softmax over each group, computed on CPU.
    out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
    out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
                         dim=-1)

    # Scatter the per-group results back into the original element order.
    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0).to(device)

    assert torch.allclose(out, expected)


30
31
32
33
34
35
36
37
38
39
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_softmax_broadcasting(dtype, device):
    """Softmax along dim 0 with five two-element groups: each group of
    probabilities must sum to one for every trailing feature column."""
    index = tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
    src = torch.randn(10, 5, dtype=dtype, device=device)

    probs = scatter_softmax(src, index, dim=0)
    # Rows are grouped consecutively in pairs, so reshaping to
    # (groups, members, features) and summing over members yields 1s.
    group_totals = probs.view(5, 2, 5).sum(dim=1)
    assert torch.allclose(group_totals, torch.ones_like(group_totals))


40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_log_softmax(dtype, device):
    """scatter_log_softmax must equal torch.log_softmax per index group.

    Mirrors ``test_softmax`` with the same groups: {0.2, 0.2} -> 0,
    {0, -2.1, 3.2} -> 1, {7} -> 2, {-1, -inf} -> 4; index 3 is an empty
    segment and ``-inf`` exercises the log-of-zero-probability path.
    """
    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_log_softmax(src, index)

    # Reference values: plain log_softmax over each group, computed on CPU.
    out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
                             dim=-1)

    # Scatter the per-group results back into the original element order.
    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0).to(device)

    assert torch.allclose(out, expected)