# test_logsumexp.py — tests for torch_scatter.scatter_logsumexp
# (last committed by rusty1s)
import torch
from torch_scatter import scatter_logsumexp


def test_logsumexp():
    """Check scatter_logsumexp against torch.logsumexp applied group by group.

    ``index`` assigns each entry of ``inputs`` to an output slot (0..6), and
    ``splits`` lists the number of entries per slot in the same order, so
    ``inputs.split(splits)`` reproduces the groups that ``index`` defines.
    Slot 3 receives no entries (split size 0), slot 5 holds a single -inf,
    and slot 6 mixes -inf with a finite value. The final ``backward`` call
    only checks that a gradient can be propagated without raising; gradient
    values are not inspected.
    """
    inputs = torch.tensor([
        0.5, 0.5, 0.0, -2.1, 3.2, 7.0, -1.0, -100.0,
        float('-inf'),
        float('-inf'), 0.0
    ])
    inputs.requires_grad_()
    index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4, 5, 6, 6])
    # Group sizes for slots 0..6; must match the per-slot counts in `index`.
    splits = [2, 3, 1, 0, 2, 1, 2]

    outputs = scatter_logsumexp(inputs, index)

    # zip() below stops at the shorter sequence, so without this check a
    # missing output row would silently skip the remaining groups.
    assert outputs.size(0) == len(splits)

    for src, out in zip(inputs.split(splits), outputs.unbind()):
        # The empty group (slot 3) is compared against torch.logsumexp of an
        # empty tensor.
        assert out.tolist() == torch.logsumexp(src, dim=0).tolist()

    outputs.backward(torch.randn_like(outputs))