test_softmax.py

import sys

import backend as F

import dgl
import pytest
import torch

from dgl.sparse import from_coo, softmax


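# Check that dgl.sparse.softmax matches dgl.nn.functional.edge_softmax, both in
# the forward values and in the gradients w.r.t. the nonzero values.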
@pytest.mark.parametrize("val_D", [None, 2])
@pytest.mark.parametrize("csr", [True, False])
@pytest.mark.parametrize("dim", [0, 1])
def test_softmax(val_D, csr, dim):
    dev = F.ctx()
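    # A small sparse matrix with four nonzero entries, specified as COO
    # coordinates. Each nonzero carries a scalar when val_D is None, otherwise
    # a vector of length val_D.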
    row = torch.tensor([0, 0, 1, 1]).to(dev)
    col = torch.tensor([0, 2, 1, 2]).to(dev)
    nnz = len(row)
    if val_D is None:
        val = torch.randn(nnz).to(dev)
    else:
        val = torch.randn(nnz, val_D).to(dev)

    val_sparse = val.clone().requires_grad_()
    A = from_coo(row, col, val_sparse)

    if csr:
        # Convert to CSR to also test the CSR code path
        A.csr()

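    # Compare against edge_softmax on the corresponding graph. edge_softmax
    # normalizes the scores of each node's incoming edges, so the edges are
    # reversed for dim=1 (per-row normalization) and kept as (row, col) for
    # dim=0 (per-column normalization).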
    A_max = softmax(A, dim)
    if dim == 1:
        g = dgl.graph((col, row), num_nodes=max(A.shape))
    else:
        g = dgl.graph((row, col), num_nodes=max(A.shape))
    val_g = val.clone().requires_grad_()
    score = dgl.nn.functional.edge_softmax(g, val_g)
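    # Forward results of the sparse softmax and the edge_softmax reference
    # should match.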
    assert torch.allclose(A_max.val, score, atol=1e-05)

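    # Backpropagate the same random gradient through both implementations and
    # compare the gradients w.r.t. the input values.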
    grad = torch.randn_like(score).to(dev)
    A_max.val.backward(grad)
    score.backward(grad)
    assert torch.allclose(A.val.grad, val_g.grad, atol=1e-05)