test_logsumexp.py 679 Bytes
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from torch_scatter import scatter_logsumexp


def test_logsumexp():
    """Compare ``scatter_logsumexp`` with ``torch.logsumexp`` per group.

    The index assigns the eleven input values to seven groups (0..6).
    Group 3 receives no values, group 5 holds a single ``-inf`` and group 6
    mixes ``-inf`` with a finite value.  Every scattered output is required
    to be exactly equal to ``torch.logsumexp`` over the matching slice of
    the input.  The test then runs a backward pass and checks that the
    TorchScript-compiled function reproduces the eager result.
    """
    neg_inf = float('-inf')
    values = [0.5, 0.5, 0.0, -2.1, 3.2, 7.0, -1.0, -100.0,
              neg_inf, neg_inf, 0.0]
    inputs = torch.tensor(values).requires_grad_()

    index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4, 5, 6, 6])
    # Number of inputs per index value 0..6; the 0 entry is the empty
    # group 3, which never appears in ``index``.
    group_sizes = [2, 3, 1, 0, 2, 1, 2]

    outputs = scatter_logsumexp(inputs, index)

    # ``index`` is sorted, so contiguous chunks of ``inputs`` line up with
    # the entries of ``outputs`` one to one.
    chunks = inputs.split(group_sizes)
    for chunk, scattered in zip(chunks, outputs.unbind()):
        reference = torch.logsumexp(chunk, dim=0)
        assert scattered.tolist() == reference.tolist()

    # Only checks that the backward pass runs; the resulting gradients are
    # not inspected.
    upstream_grad = torch.randn_like(outputs)
    outputs.backward(upstream_grad)

    # The scripted function must give the same values as the eager call.
    scripted = torch.jit.script(scatter_logsumexp)
    scripted_out = scripted(inputs, index)
    assert scripted_out.tolist() == outputs.tolist()